|
| 1 | +from unittest.mock import patch |
| 2 | +from urllib.parse import parse_qs, urlparse |
| 3 | + |
1 | 4 | import pytest
|
2 |
| -from django.contrib.auth import get_user |
| 5 | +from django.contrib.auth import get_user, get_user_model |
3 | 6 | from django.contrib.auth.models import AnonymousUser
|
4 | 7 | from django.test import RequestFactory
|
5 | 8 | from django.urls import reverse
|
|
12 | 15 | InvalidOIDCClientError,
|
13 | 16 | InvalidOIDCRedirectURIError,
|
14 | 17 | )
|
15 |
| -from oauth2_provider.models import get_access_token_model, get_id_token_model, get_refresh_token_model |
| 18 | +from oauth2_provider.models import ( |
| 19 | + get_access_token_model, |
| 20 | + get_application_model, |
| 21 | + get_id_token_model, |
| 22 | + get_refresh_token_model, |
| 23 | +) |
16 | 24 | from oauth2_provider.oauth2_validators import OAuth2Validator
|
17 | 25 | from oauth2_provider.settings import oauth2_settings
|
| 26 | +from oauth2_provider.views.base import AuthorizationView |
18 | 27 | from oauth2_provider.views.oidc import RPInitiatedLogoutView, _load_id_token, _validate_claims
|
19 | 28 |
|
20 | 29 | from . import presets
|
@@ -47,6 +56,7 @@ def test_get_connect_discovery_info(self):
|
47 | 56 | "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
48 | 57 | "code_challenge_methods_supported": ["plain", "S256"],
|
49 | 58 | "claims_supported": ["sub"],
|
| 59 | + "prompt_values_supported": ["none", "login"], |
50 | 60 | }
|
51 | 61 | response = self.client.get("/o/.well-known/openid-configuration")
|
52 | 62 | self.assertEqual(response.status_code, 200)
|
@@ -74,6 +84,7 @@ def test_get_connect_discovery_info_deprecated(self):
|
74 | 84 | "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
75 | 85 | "code_challenge_methods_supported": ["plain", "S256"],
|
76 | 86 | "claims_supported": ["sub"],
|
| 87 | + "prompt_values_supported": ["none", "login"], |
77 | 88 | }
|
78 | 89 | response = self.client.get("/o/.well-known/openid-configuration/")
|
79 | 90 | self.assertEqual(response.status_code, 200)
|
@@ -101,6 +112,7 @@ def expect_json_response_with_rp_logout(self, base):
|
101 | 112 | "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
102 | 113 | "code_challenge_methods_supported": ["plain", "S256"],
|
103 | 114 | "claims_supported": ["sub"],
|
| 115 | + "prompt_values_supported": ["none", "login"], |
104 | 116 | "end_session_endpoint": f"{base}/logout/",
|
105 | 117 | }
|
106 | 118 | response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info"))
|
@@ -135,6 +147,7 @@ def test_get_connect_discovery_info_without_issuer_url(self):
|
135 | 147 | "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
136 | 148 | "code_challenge_methods_supported": ["plain", "S256"],
|
137 | 149 | "claims_supported": ["sub"],
|
| 150 | + "prompt_values_supported": ["none", "login"], |
138 | 151 | }
|
139 | 152 | response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info"))
|
140 | 153 | self.assertEqual(response.status_code, 200)
|
@@ -206,6 +219,140 @@ def test_get_jwks_info_multiple_rsa_keys(self):
|
206 | 219 | assert response.json() == expected_response
|
207 | 220 |
|
208 | 221 |
|
| 222 | +@pytest.mark.usefixtures("oauth2_settings") |
| 223 | +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_REGISTRATION) |
| 224 | +class TestRPInitiatedRegistration(TestCase): |
| 225 | + def setUp(self): |
| 226 | + Application = get_application_model() |
| 227 | + self.application = Application.objects.create( |
| 228 | + name="Test Application", |
| 229 | + redirect_uris="http://localhost http://example.com", |
| 230 | + client_type=Application.CLIENT_CONFIDENTIAL, |
| 231 | + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, |
| 232 | + ) |
| 233 | + User = get_user_model() |
| 234 | + self. test_user = User. objects. create_user( "test_user", "[email protected]", "123456") |
| 235 | + |
| 236 | + def _build_authorization_request(self, query_params, user=None): |
| 237 | + auth_url = reverse("oauth2_provider:authorize") |
| 238 | + query_string = "&".join(f"{k}={v}" for k, v in query_params.items()) |
| 239 | + full_auth_url = f"{auth_url}?{query_string}" |
| 240 | + request = RequestFactory().get(full_auth_url) |
| 241 | + request.user = user or AnonymousUser() |
| 242 | + return request |
| 243 | + |
| 244 | + def test_connect_discovery_info_has_create(self): |
| 245 | + expected_response = { |
| 246 | + "issuer": "http://localhost/o", |
| 247 | + "authorization_endpoint": "http://localhost/o/authorize/", |
| 248 | + "token_endpoint": "http://localhost/o/token/", |
| 249 | + "userinfo_endpoint": "http://localhost/o/userinfo/", |
| 250 | + "jwks_uri": "http://localhost/o/.well-known/jwks.json", |
| 251 | + "scopes_supported": ["read", "write", "openid"], |
| 252 | + "response_types_supported": [ |
| 253 | + "code", |
| 254 | + "token", |
| 255 | + "id_token", |
| 256 | + "id_token token", |
| 257 | + "code token", |
| 258 | + "code id_token", |
| 259 | + "code id_token token", |
| 260 | + ], |
| 261 | + "subject_types_supported": ["public"], |
| 262 | + "id_token_signing_alg_values_supported": ["RS256", "HS256"], |
| 263 | + "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], |
| 264 | + "code_challenge_methods_supported": ["plain", "S256"], |
| 265 | + "claims_supported": ["sub"], |
| 266 | + "prompt_values_supported": ["none", "login", "create"], |
| 267 | + } |
| 268 | + response = self.client.get("/o/.well-known/openid-configuration") |
| 269 | + self.assertEqual(response.status_code, 200) |
| 270 | + assert response.json() == expected_response |
| 271 | + |
| 272 | + def test_prompt_create_redirects_to_registration_view(self): |
| 273 | + request = self._build_authorization_request( |
| 274 | + query_params={ |
| 275 | + "response_type": "code", |
| 276 | + "client_id": self.application.client_id, |
| 277 | + "redirect_uri": "http://localhost", |
| 278 | + "scope": "openid", |
| 279 | + "prompt": "create", |
| 280 | + } |
| 281 | + ) |
| 282 | + view = AuthorizationView() |
| 283 | + view.setup(request) |
| 284 | + |
| 285 | + with patch("oauth2_provider.views.base.reverse") as patched_reverse: |
| 286 | + patched_reverse.return_value = "/register-test/" |
| 287 | + response = view.get(request) |
| 288 | + |
| 289 | + self.assertEqual(response.status_code, 302) |
| 290 | + redirect_url = response.url |
| 291 | + parsed_url = urlparse(redirect_url) |
| 292 | + |
| 293 | + # Verify it's the registration URL |
| 294 | + self.assertEqual(parsed_url.path, "/register-test/") |
| 295 | + |
| 296 | + # Verify the query parameters |
| 297 | + query = parse_qs(parsed_url.query) |
| 298 | + self.assertIn("next", query) |
| 299 | + |
| 300 | + # Verify the next parameter doesn't contain prompt=create |
| 301 | + next_url = query["next"][0] |
| 302 | + self.assertNotIn("prompt=create", next_url) |
| 303 | + |
| 304 | + # But it should contain the other original parameters |
| 305 | + self.assertIn("response_type=code", next_url) |
| 306 | + self.assertIn(f"client_id={self.application.client_id}", next_url) |
| 307 | + |
| 308 | + def test_logged_users_can_not_prompt_create(self): |
| 309 | + view = AuthorizationView() |
| 310 | + request = self._build_authorization_request( |
| 311 | + query_params={ |
| 312 | + "response_type": "code", |
| 313 | + "client_id": self.application.client_id, |
| 314 | + "redirect_uri": "http://localhost", |
| 315 | + "scope": "openid", |
| 316 | + "prompt": "create", |
| 317 | + }, |
| 318 | + user=self.test_user, |
| 319 | + ) |
| 320 | + view.setup(request) |
| 321 | + response = view.get(request) |
| 322 | + self.assertEqual(response.status_code, 302) |
| 323 | + self.assertIn("account_selection_required", response.url) |
| 324 | + |
| 325 | + def test_state_is_echoed_on_bad_requests(self): |
| 326 | + state_query_params = { |
| 327 | + "response_type": "code", |
| 328 | + "client_id": self.application.client_id, |
| 329 | + "redirect_uri": "http://localhost", |
| 330 | + "scope": "openid", |
| 331 | + "prompt": "create", |
| 332 | + "state": "testing_state", |
| 333 | + } |
| 334 | + view = AuthorizationView() |
| 335 | + request = self._build_authorization_request(query_params=state_query_params, user=self.test_user) |
| 336 | + view.setup(request) |
| 337 | + response = view.get(request) |
| 338 | + self.assertIn("state=testing_state", response.url) |
| 339 | + |
| 340 | + def test_bad_request_if_missing_redirect_uri(self): |
| 341 | + view = AuthorizationView() |
| 342 | + request = self._build_authorization_request( |
| 343 | + query_params={ |
| 344 | + "response_type": "code", |
| 345 | + "client_id": self.application.client_id, |
| 346 | + "scope": "openid", |
| 347 | + "prompt": "create", |
| 348 | + }, |
| 349 | + user=self.test_user, |
| 350 | + ) |
| 351 | + view.setup(request) |
| 352 | + response = view.handle_prompt_create() |
| 353 | + self.assertEqual(response.status_code, 400) |
| 354 | + |
| 355 | + |
209 | 356 | def mock_request():
|
210 | 357 | """
|
211 | 358 | Dummy request with an AnonymousUser attached.
|
|
0 commit comments