|
22 | 22 | import requests
|
23 | 23 |
|
24 | 24 | from requests.auth import _basic_auth_str
|
| 25 | +from requests.exceptions import HTTPError |
25 | 26 |
|
26 | 27 |
|
27 | 28 | fake_time = time.time()
|
28 | 29 | CODE = "asdf345xdf"
|
29 | 30 |
|
30 | 31 |
|
31 |
| -def fake_token(token): |
| 32 | +def fake_token(token, status_code: int = 200): |
32 | 33 | def fake_send(r, **kwargs):
|
33 | 34 | resp = mock.MagicMock()
|
34 |
| - resp.status_code = 200 |
| 35 | + resp.status_code = status_code |
35 | 36 | resp.text = json.dumps(token)
|
36 | 37 | return resp
|
37 | 38 |
|
@@ -133,11 +134,11 @@ def test_refresh_token_request(self):
|
133 | 134 | self.expired_token["expires_in"] = "-1"
|
134 | 135 | del self.expired_token["expires_at"]
|
135 | 136 |
|
136 |
| - def fake_refresh(r, **kwargs): |
| 137 | + def fake_refresh(r, status_code: int = 200, **kwargs): |
137 | 138 | if "/refresh" in r.url:
|
138 | 139 | self.assertNotIn("Authorization", r.headers)
|
139 | 140 | resp = mock.MagicMock()
|
140 |
| - resp.status_code = 200 |
| 141 | + resp.status_code = status_code |
141 | 142 | resp.text = json.dumps(self.token)
|
142 | 143 | return resp
|
143 | 144 |
|
@@ -170,6 +171,19 @@ def token_updater(token):
|
170 | 171 | sess.send = fake_refresh
|
171 | 172 | sess.get("https://i.b")
|
172 | 173 |
|
| 174 | + # test 5xx error handler |
| 175 | + for client in self.clients: |
| 176 | + sess = OAuth2Session( |
| 177 | + client=client, |
| 178 | + token=self.expired_token, |
| 179 | + auto_refresh_url="https://i.b/refresh", |
| 180 | + token_updater=token_updater, |
| 181 | + ) |
| 182 | + sess.send = lambda r, **kwargs: fake_refresh( |
| 183 | + r=r, status_code=503, kwargs=kwargs, |
| 184 | + ) |
| 185 | + self.assertRaises(HTTPError, sess.get, "https://i.b") |
| 186 | + |
173 | 187 | def fake_refresh_with_auth(r, **kwargs):
|
174 | 188 | if "/refresh" in r.url:
|
175 | 189 | self.assertIn("Authorization", r.headers)
|
@@ -256,6 +270,23 @@ def test_fetch_token(self):
|
256 | 270 | else:
|
257 | 271 | self.assertRaises(OAuth2Error, sess.fetch_token, url)
|
258 | 272 |
|
| 273 | + # test 5xx error responses |
| 274 | + error = {"error": "server error!"} |
| 275 | + for client in self.clients: |
| 276 | + sess = OAuth2Session(client=client, token=self.token) |
| 277 | + sess.send = fake_token(error, status_code=500) |
| 278 | + if isinstance(client, LegacyApplicationClient): |
| 279 | + # this client requires a username+password |
| 280 | + self.assertRaises( |
| 281 | + HTTPError, |
| 282 | + sess.fetch_token, |
| 283 | + url, |
| 284 | + username="username1", |
| 285 | + password="password1", |
| 286 | + ) |
| 287 | + else: |
| 288 | + self.assertRaises(HTTPError, sess.fetch_token, url) |
| 289 | + |
259 | 290 | # there are different scenarios in which the `client_id` can be specified
|
260 | 291 | # reference `oauthlib.tests.oauth2.rfc6749.clients.test_web_application.WebApplicationClientTest.test_prepare_request_body`
|
261 | 292 | # this only needs to test WebApplicationClient
|
|
0 commit comments