Skip to content

Support falling back to OIDC metadata for auth #1061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 17, 2025
117 changes: 46 additions & 71 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,72 +251,32 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
except ValidationError:
pass

def _build_well_known_path(self, pathname: str) -> str:
"""Construct well-known path for OAuth metadata discovery."""
well_known_path = f"/.well-known/oauth-authorization-server{pathname}"
if pathname.endswith("/"):
# Strip trailing slash from pathname to avoid double slashes
well_known_path = well_known_path[:-1]
return well_known_path

def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool:
"""Determine if fallback to root discovery should be attempted."""
return response_status == 404 and pathname != "/"

async def _try_metadata_discovery(self, url: str) -> httpx.Request:
"""Build metadata discovery request for a specific URL."""
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _discover_oauth_metadata(self) -> httpx.Request:
"""Build OAuth metadata discovery request with fallback support."""
if self.context.auth_server_url:
auth_server_url = self.context.auth_server_url
else:
auth_server_url = self.context.server_url

# Per RFC 8414, try path-aware discovery first
def _get_discovery_urls(self) -> list[str]:
"""Generate ordered list of (url, type) tuples for discovery attempts."""
urls: list[str] = []
auth_server_url = self.context.auth_server_url or self.context.server_url
parsed = urlparse(auth_server_url)
well_known_path = self._build_well_known_path(parsed.path)
base_url = f"{parsed.scheme}://{parsed.netloc}"
url = urljoin(base_url, well_known_path)

# Store fallback info for use in response handler
self.context.discovery_base_url = base_url
self.context.discovery_pathname = parsed.path
# RFC 8414: Path-aware OAuth discovery
if parsed.path and parsed.path != "/":
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
urls.append(urljoin(base_url, oauth_path))

return await self._try_metadata_discovery(url)
# OAuth root fallback
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))

async def _discover_oauth_metadata_fallback(self) -> httpx.Request:
"""Build fallback OAuth metadata discovery request for legacy servers."""
base_url = getattr(self.context, "discovery_base_url", "")
if not base_url:
raise OAuthFlowError("No base URL available for fallback discovery")

# Fallback to root discovery for legacy servers
url = urljoin(base_url, "/.well-known/oauth-authorization-server")
return await self._try_metadata_discovery(url)

async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool:
"""Handle OAuth metadata response. Returns True if handled successfully."""
if response.status_code == 200:
try:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata
# Apply default scope if none specified
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)
return True
except ValidationError:
pass
# RFC 8414 section 5: Path-aware OIDC discovery
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
if parsed.path and parsed.path != "/":
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
urls.append(urljoin(base_url, oidc_path))

# Check if we should attempt fallback (404 on path-aware discovery)
if not is_fallback and self._should_attempt_fallback(
response.status_code, getattr(self.context, "discovery_pathname", "/")
):
return False # Signal that fallback should be attempted
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration"
urls.append(oidc_fallback)

return True # Signal no fallback needed (either success or non-404 error)
return urls

async def _register_client(self) -> httpx.Request | None:
"""Build registration request or skip if already registered."""
Expand Down Expand Up @@ -511,6 +471,17 @@ def _add_auth_header(self, request: httpx.Request) -> None:
if self.context.current_tokens and self.context.current_tokens.access_token:
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"

def _create_oauth_metadata_request(self, url: str) -> httpx.Request:
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata
# Apply default scope if needed
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
Expand Down Expand Up @@ -544,15 +515,19 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
oauth_response = yield oauth_request
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)

# If path-aware discovery failed with 404, try fallback to root
if not handled:
fallback_request = await self._discover_oauth_metadata_fallback()
fallback_response = yield fallback_request
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
discovery_urls = self._get_discovery_urls()
for url in discovery_urls:
request = self._create_oauth_metadata_request(url)
response = yield request

if response.status_code == 200:
try:
await self._handle_oauth_metadata_response(response)
break
except ValidationError:
continue
elif response.status_code != 404:
break # Non-404 error, stop trying

# Step 3: Register client if needed
registration_request = await self._register_client()
Expand All @@ -571,6 +546,6 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
logger.exception("OAuth flow error")
raise

# Retry with new tokens
self._add_auth_header(request)
yield request
# Retry with new tokens
self._add_auth_header(request)
yield request
146 changes: 15 additions & 131 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,107 +235,30 @@ async def callback_handler() -> tuple[str, str | None]:
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_discover_oauth_metadata_request(self, oauth_provider):
def test_create_oauth_metadata_request(self, oauth_provider):
"""Test OAuth metadata discovery request building."""
request = await oauth_provider._discover_oauth_metadata()
request = oauth_provider._create_oauth_metadata_request("https://example.com")

# Ensure correct method and headers, and that the URL is unmodified
assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_discover_oauth_metadata_request_no_path(self, client_metadata, mock_storage):
"""Test OAuth metadata discovery request building when server has no path."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

request = await provider._discover_oauth_metadata()

assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server"
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_discover_oauth_metadata_request_trailing_slash(self, client_metadata, mock_storage):
"""Test OAuth metadata discovery request building when server path has trailing slash."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp/",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

request = await provider._discover_oauth_metadata()

assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
assert str(request.url) == "https://example.com"
assert "mcp-protocol-version" in request.headers


class TestOAuthFallback:
"""Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers."""

@pytest.mark.anyio
async def test_fallback_discovery_request(self, client_metadata, mock_storage):
"""Test fallback discovery request building."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

# Set up discovery state manually as if path-aware discovery was attempted
provider.context.discovery_base_url = "https://api.example.com"
provider.context.discovery_pathname = "/v1/mcp"
async def test_oauth_discovery_fallback_order(self, oauth_provider):
"""Test fallback URL construction order."""
discovery_urls = oauth_provider._get_discovery_urls()

# Test fallback request building
request = await provider._discover_oauth_metadata_fallback()

assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server"
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_should_attempt_fallback(self, oauth_provider):
"""Test fallback decision logic."""
# Should attempt fallback on 404 with non-root path
assert oauth_provider._should_attempt_fallback(404, "/v1/mcp")

# Should NOT attempt fallback on 404 with root path
assert not oauth_provider._should_attempt_fallback(404, "/")

# Should NOT attempt fallback on other status codes
assert not oauth_provider._should_attempt_fallback(200, "/v1/mcp")
assert not oauth_provider._should_attempt_fallback(500, "/v1/mcp")
assert discovery_urls == [
"https://api.example.com/.well-known/oauth-authorization-server/v1/mcp",
"https://api.example.com/.well-known/oauth-authorization-server",
"https://api.example.com/.well-known/openid-configuration/v1/mcp",
"https://api.example.com/v1/mcp/.well-known/openid-configuration",
]

@pytest.mark.anyio
async def test_handle_metadata_response_success(self, oauth_provider):
Expand All @@ -348,50 +271,11 @@ async def test_handle_metadata_response_success(self, oauth_provider):
}"""
response = httpx.Response(200, content=content)

# Should return True (success) and set metadata
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
assert result is True
# Should set metadata
await oauth_provider._handle_oauth_metadata_response(response)
assert oauth_provider.context.oauth_metadata is not None
assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/"

@pytest.mark.anyio
async def test_handle_metadata_response_404_needs_fallback(self, oauth_provider):
"""Test 404 response handling that should trigger fallback."""
# Set up discovery state for non-root path
oauth_provider.context.discovery_base_url = "https://api.example.com"
oauth_provider.context.discovery_pathname = "/v1/mcp"

# Mock 404 response
response = httpx.Response(404)

# Should return False (needs fallback)
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
assert result is False

@pytest.mark.anyio
async def test_handle_metadata_response_404_no_fallback_needed(self, oauth_provider):
"""Test 404 response handling when no fallback is needed."""
# Set up discovery state for root path
oauth_provider.context.discovery_base_url = "https://api.example.com"
oauth_provider.context.discovery_pathname = "/"

# Mock 404 response
response = httpx.Response(404)

# Should return True (no fallback needed)
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
assert result is True

@pytest.mark.anyio
async def test_handle_metadata_response_404_fallback_attempt(self, oauth_provider):
"""Test 404 response handling during fallback attempt."""
# Mock 404 response during fallback
response = httpx.Response(404)

# Should return True (fallback attempt complete, no further action needed)
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=True)
assert result is True

@pytest.mark.anyio
async def test_register_client_request(self, oauth_provider):
"""Test client registration request building."""
Expand Down
39 changes: 39 additions & 0 deletions tests/shared/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Tests for OAuth 2.0 shared code."""

from mcp.shared.auth import OAuthMetadata


class TestOAuthMetadata:
"""Tests for OAuthMetadata parsing."""

def test_oauth(self):
"""Should not throw when parsing OAuth metadata."""
OAuthMetadata.model_validate(
{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/oauth2/authorize",
"token_endpoint": "https://example.com/oauth2/token",
"scopes_supported": ["read", "write"],
"response_types_supported": ["code", "token"],
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
}
)

def test_oidc(self):
"""Should not throw when parsing OIDC metadata."""
OAuthMetadata.model_validate(
{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/oauth2/authorize",
"token_endpoint": "https://example.com/oauth2/token",
"end_session_endpoint": "https://example.com/logout",
"id_token_signing_alg_values_supported": ["RS256"],
"jwks_uri": "https://example.com/.well-known/jwks.json",
"response_types_supported": ["code", "token"],
"revocation_endpoint": "https://example.com/oauth2/revoke",
"scopes_supported": ["openid", "read", "write"],
"subject_types_supported": ["public"],
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
"userinfo_endpoint": "https://example.com/oauth2/userInfo",
}
)
Loading