Skip to content

Commit 24cbc51

Browse files
authored
fix: make token expiry optional in OAuth2 response (#462)
if the auth server doesn't return the token expiry, use a default of 1 hour.
1 parent 21b6413 commit 24cbc51

File tree

5 files changed

+218
-88
lines changed

5 files changed

+218
-88
lines changed

airbyte_cdk/sources/declarative/auth/oauth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
239239
def _has_access_token_been_initialized(self) -> bool:
240240
return self._access_token is not None
241241

242-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
243-
self._token_expiry_date = self._parse_token_expiration_date(value)
242+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
243+
self._token_expiry_date = value
244244

245245
def get_assertion_name(self) -> str:
246246
return self.assertion_name

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
130130
headers = self.get_refresh_request_headers()
131131
return headers if headers else None
132132

133-
def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
133+
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
134134
"""
135135
Returns the refresh token and its expiration datetime
136136
@@ -148,6 +148,14 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
148148
# PRIVATE METHODS
149149
# ----------------
150150

151+
def _default_token_expiry_date(self) -> AirbyteDateTime:
152+
"""
153+
Returns the default token expiry date
154+
"""
155+
# 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
156+
default_token_expiry_duration_hours = 1 # 1 hour
157+
return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
158+
151159
def _wrap_refresh_token_exception(
152160
self, exception: requests.exceptions.RequestException
153161
) -> bool:
@@ -257,14 +265,10 @@ def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) ->
257265

258266
def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
259267
"""
260-
Return the expiration datetime of the refresh token
268+
Parse a string or integer token expiration date into a datetime object
261269
262270
:return: expiration datetime
263271
"""
264-
if not value and not self.token_has_expired():
265-
# No expiry token was provided but the previous one is not expired so it's fine
266-
return self.get_token_expiry_date()
267-
268272
if self.token_expiry_is_time_of_expiration:
269273
if not self.token_expiry_date_format:
270274
raise ValueError(
@@ -308,17 +312,30 @@ def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
308312
"""
309313
return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
310314

311-
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
315+
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
312316
"""
313317
Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
314318
319+
If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
320+
315321
Args:
316322
response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
317323
318324
Returns:
319-
str: The extracted token_expiry_date.
325+
The extracted token_expiry_date or None if not found.
320326
"""
321-
return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
327+
expires_in = self._find_and_get_value_from_response(
328+
response_data, self.get_expires_in_name()
329+
)
330+
if expires_in is not None:
331+
return self._parse_token_expiration_date(expires_in)
332+
333+
# expires_in is None
334+
existing_expiry_date = self.get_token_expiry_date()
335+
if existing_expiry_date and not self.token_has_expired():
336+
return existing_expiry_date
337+
338+
return self._default_token_expiry_date()
322339

323340
def _find_and_get_value_from_response(
324341
self,
@@ -344,7 +361,7 @@ def _find_and_get_value_from_response(
344361
"""
345362
if current_depth > max_depth:
346363
# this is needed to avoid an inf loop, possible with a very deep nesting observed.
347-
message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
364+
message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
348365
raise ResponseKeysMaxRecurtionReached(
349366
internal_message=message, message=message, failure_type=FailureType.config_error
350367
)
@@ -441,7 +458,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
441458
"""Expiration date of the access token"""
442459

443460
@abstractmethod
444-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
461+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
445462
"""Setter for access token expiration date"""
446463

447464
@abstractmethod

airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ def get_grant_type(self) -> str:
120120
def get_token_expiry_date(self) -> AirbyteDateTime:
121121
return self._token_expiry_date
122122

123-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
124-
self._token_expiry_date = self._parse_token_expiration_date(value)
123+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
124+
self._token_expiry_date = value
125125

126126
@property
127127
def token_expiry_is_time_of_expiration(self) -> bool:
@@ -316,26 +316,6 @@ def token_has_expired(self) -> bool:
316316
"""Returns True if the token is expired"""
317317
return ab_datetime_now() > self.get_token_expiry_date()
318318

319-
@staticmethod
320-
def get_new_token_expiry_date(
321-
access_token_expires_in: str,
322-
token_expiry_date_format: str | None = None,
323-
) -> AirbyteDateTime:
324-
"""
325-
Calculate the new token expiry date based on the provided expiration duration or format.
326-
327-
Args:
328-
access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format.
329-
token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None.
330-
331-
Returns:
332-
AirbyteDateTime: The calculated expiry date of the access token.
333-
"""
334-
if token_expiry_date_format:
335-
return ab_datetime_parse(access_token_expires_in)
336-
else:
337-
return ab_datetime_now() + timedelta(seconds=int(access_token_expires_in))
338-
339319
def get_access_token(self) -> str:
340320
"""Retrieve new access and refresh token if the access token has expired.
341321
The new refresh token is persisted with the set_refresh_token function
@@ -346,16 +326,13 @@ def get_access_token(self) -> str:
346326
new_access_token, access_token_expires_in, new_refresh_token = (
347327
self.refresh_access_token()
348328
)
349-
new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date(
350-
access_token_expires_in, self._token_expiry_date_format
351-
)
352329
self.access_token = new_access_token
353330
self.set_refresh_token(new_refresh_token)
354-
self.set_token_expiry_date(new_token_expiry_date)
331+
self.set_token_expiry_date(access_token_expires_in)
355332
self._emit_control_message()
356333
return self.access_token
357334

358-
def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override]
335+
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
359336
"""
360337
Refreshes the access token by making a handled request and extracting the necessary token information.
361338

unit_tests/sources/declarative/auth/test_oauth.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def test_error_on_refresh_token_grant_without_refresh_token(self):
203203
grant_type="refresh_token",
204204
)
205205

206+
@freezegun.freeze_time("2022-01-01")
206207
def test_refresh_access_token(self, mocker):
207208
oauth = DeclarativeOauth2Authenticator(
208209
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
@@ -225,13 +226,15 @@ def test_refresh_access_token(self, mocker):
225226
resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}
226227
)
227228
mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True)
228-
token = oauth.refresh_access_token()
229+
access_token, token_expiry_date = oauth.refresh_access_token()
229230

230-
assert ("access_token", 1000) == token
231+
assert access_token == "access_token"
232+
assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)
231233

232234
filtered = filter_secrets("access_token")
233235
assert filtered == "****"
234236

237+
@freezegun.freeze_time("2022-01-01")
235238
def test_refresh_access_token_when_headers_provided(self, mocker):
236239
expected_headers = {
237240
"Authorization": "Bearer some_access_token",
@@ -256,9 +259,10 @@ def test_refresh_access_token_when_headers_provided(self, mocker):
256259
mocked_request = mocker.patch.object(
257260
requests, "request", side_effect=mock_request, autospec=True
258261
)
259-
token = oauth.refresh_access_token()
262+
access_token, token_expiry_date = oauth.refresh_access_token()
260263

261-
assert ("access_token", 1000) == token
264+
assert access_token == "access_token"
265+
assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)
262266

263267
assert mocked_request.call_args.kwargs["headers"] == expected_headers
264268

@@ -314,6 +318,7 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp(
314318
assert isinstance(oauth._token_expiry_date, AirbyteDateTime)
315319
assert oauth.get_token_expiry_date() == ab_datetime_parse(expected_date)
316320

321+
@freezegun.freeze_time("2022-01-01")
317322
def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_fetch_access_token(
318323
self,
319324
) -> None:
@@ -335,12 +340,65 @@ def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_
335340
url="https://refresh_endpoint.com/",
336341
body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token",
337342
),
338-
HttpResponse(body=json.dumps({"access_token": "new_access_token"})),
343+
HttpResponse(
344+
body=json.dumps({"access_token": "new_access_token", "expires_in": 1000})
345+
),
339346
)
340347
oauth.get_access_token()
341348

342349
assert oauth.access_token == "new_access_token"
343-
assert oauth._token_expiry_date == expiry_date
350+
assert oauth._token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)
351+
352+
@freezegun.freeze_time("2022-01-01")
353+
@pytest.mark.parametrize(
354+
"initial_expiry_date_delta, expected_new_expiry_date_delta, expected_access_token",
355+
[
356+
(timedelta(days=1), timedelta(days=1), "some_access_token"),
357+
(timedelta(days=-1), timedelta(hours=1), "new_access_token"),
358+
(None, timedelta(hours=1), "new_access_token"),
359+
],
360+
ids=[
361+
"initial_expiry_date_in_future",
362+
"initial_expiry_date_in_past",
363+
"no_initial_expiry_date",
364+
],
365+
)
366+
def test_no_expiry_date_provided_by_auth_server(
367+
self,
368+
initial_expiry_date_delta,
369+
expected_new_expiry_date_delta,
370+
expected_access_token,
371+
) -> None:
372+
initial_expiry_date = (
373+
ab_datetime_now().add(initial_expiry_date_delta).isoformat()
374+
if initial_expiry_date_delta
375+
else None
376+
)
377+
expected_new_expiry_date = ab_datetime_now().add(expected_new_expiry_date_delta)
378+
oauth = DeclarativeOauth2Authenticator(
379+
token_refresh_endpoint="https://refresh_endpoint.com/",
380+
client_id="some_client_id",
381+
client_secret="some_client_secret",
382+
token_expiry_date=initial_expiry_date,
383+
access_token_value="some_access_token",
384+
refresh_token="some_refresh_token",
385+
config={},
386+
parameters={},
387+
grant_type="client",
388+
)
389+
390+
with HttpMocker() as http_mocker:
391+
http_mocker.post(
392+
HttpRequest(
393+
url="https://refresh_endpoint.com/",
394+
body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token",
395+
),
396+
HttpResponse(body=json.dumps({"access_token": "new_access_token"})),
397+
)
398+
oauth.get_access_token()
399+
400+
assert oauth.access_token == expected_access_token
401+
assert oauth._token_expiry_date == expected_new_expiry_date
344402

345403
@pytest.mark.parametrize(
346404
"expires_in_response, token_expiry_date_format",
@@ -443,6 +501,7 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next
443501
assert "access_token" == token
444502
assert oauth.get_token_expiry_date() == ab_datetime_parse(next_day)
445503

504+
@freezegun.freeze_time("2022-01-01")
446505
def test_profile_assertion(self, mocker):
447506
with HttpMocker() as http_mocker:
448507
jwt = JwtAuthenticator(
@@ -477,7 +536,7 @@ def test_profile_assertion(self, mocker):
477536

478537
token = oauth.refresh_access_token()
479538

480-
assert ("access_token", 1000) == token
539+
assert ("access_token", ab_datetime_now().add(timedelta(seconds=1000))) == token
481540

482541
filtered = filter_secrets("access_token")
483542
assert filtered == "****"

0 commit comments

Comments
 (0)