Skip to content

Commit b3decdb

Browse files
authored
Merge pull request #433 from violuke/master
Adding refresh_token_request and access_token_request compliance hooks
2 parents 6ac6133 + 4e67803 commit b3decdb

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

requests_oauthlib/oauth2_session.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def __init__(
9595
"access_token_response": set(),
9696
"refresh_token_response": set(),
9797
"protected_request": set(),
98+
"refresh_token_request": set(),
99+
"access_token_request": set(),
98100
}
99101

100102
@property
@@ -352,6 +354,12 @@ def fetch_token(
352354
else:
353355
raise ValueError("The method kwarg must be POST or GET.")
354356

357+
for hook in self.compliance_hook["access_token_request"]:
358+
log.debug("Invoking access_token_request hook %s.", hook)
359+
token_url, headers, request_kwargs = hook(
360+
token_url, headers, request_kwargs
361+
)
362+
355363
r = self.request(
356364
method=method,
357365
url=token_url,
@@ -443,6 +451,10 @@ def refresh_token(
443451
"Content-Type": ("application/x-www-form-urlencoded"),
444452
}
445453

454+
for hook in self.compliance_hook["refresh_token_request"]:
455+
log.debug("Invoking refresh_token_request hook %s.", hook)
456+
token_url, headers, body = hook(token_url, headers, body)
457+
446458
r = self.post(
447459
token_url,
448460
data=dict(urldecode(body)),
@@ -544,6 +556,8 @@ def register_compliance_hook(self, hook_type, hook):
544556
access_token_response invoked before token parsing.
545557
refresh_token_response invoked before refresh token parsing.
546558
protected_request invoked before making a request.
559+
access_token_request invoked before making a token fetch request.
560+
refresh_token_request invoked before making a refresh request.
547561
548562
If you find a new hook is needed please send a GitHub PR request
549563
or open an issue.

tests/test_compliance_fixes.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,60 @@ def test_fetch_access_token(self):
332332
authorization_response="https://i.b/?code=hello",
333333
)
334334
assert token["token_type"] == "Bearer"
335+
336+
337+
def access_and_refresh_token_request_compliance_fix_test(session, client_secret):
338+
def _non_compliant_header(url, headers, body):
339+
headers["X-Client-Secret"] = client_secret
340+
return url, headers, body
341+
342+
session.register_compliance_hook("access_token_request", _non_compliant_header)
343+
session.register_compliance_hook("refresh_token_request", _non_compliant_header)
344+
return session
345+
346+
347+
class RefreshTokenRequestComplianceFixTest(TestCase):
348+
value_to_test_for = "value_to_test_for"
349+
350+
def setUp(self):
351+
mocker = requests_mock.Mocker()
352+
mocker.post(
353+
"https://example.com/token",
354+
request_headers={"X-Client-Secret": self.value_to_test_for},
355+
json={
356+
"access_token": "this is the access token",
357+
"expires_in": 7200,
358+
"token_type": "Bearer",
359+
},
360+
headers={"Content-Type": "application/json"},
361+
)
362+
mocker.post(
363+
"https://example.com/refresh",
364+
request_headers={"X-Client-Secret": self.value_to_test_for},
365+
json={
366+
"access_token": "this is the access token",
367+
"expires_in": 7200,
368+
"token_type": "Bearer",
369+
},
370+
headers={"Content-Type": "application/json"},
371+
)
372+
mocker.start()
373+
self.addCleanup(mocker.stop)
374+
375+
session = OAuth2Session()
376+
self.fixed_session = access_and_refresh_token_request_compliance_fix_test(
377+
session, self.value_to_test_for
378+
)
379+
380+
def test_access_token(self):
381+
token = self.fixed_session.fetch_token(
382+
"https://example.com/token",
383+
authorization_response="https://i.b/?code=hello",
384+
)
385+
assert token["token_type"] == "Bearer"
386+
387+
def test_refresh_token(self):
388+
token = self.fixed_session.refresh_token(
389+
"https://example.com/refresh",
390+
)
391+
assert token["token_type"] == "Bearer"

0 commit comments

Comments
 (0)