Skip to content

Commit ae1211f

Browse files
author
Baz
authored
fix: (DeclarativeOAuthFlow) - allow DeclarativeOauth2Authenticator to use access_token directly when no token_refresh_endpoint or refresh_token values are provided. (#182)
1 parent f8054a8 commit ae1211f

File tree

6 files changed

+75
-19
lines changed

6 files changed

+75
-19
lines changed

airbyte_cdk/sources/declarative/auth/oauth.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,28 +43,32 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
4343
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
4444
"""
4545

46-
token_refresh_endpoint: Union[InterpolatedString, str]
4746
client_id: Union[InterpolatedString, str]
4847
client_secret: Union[InterpolatedString, str]
4948
config: Mapping[str, Any]
5049
parameters: InitVar[Mapping[str, Any]]
50+
token_refresh_endpoint: Optional[Union[InterpolatedString, str]] = None
5151
refresh_token: Optional[Union[InterpolatedString, str]] = None
5252
scopes: Optional[List[str]] = None
5353
token_expiry_date: Optional[Union[InterpolatedString, str]] = None
5454
_token_expiry_date: Optional[pendulum.DateTime] = field(init=False, repr=False, default=None)
5555
token_expiry_date_format: Optional[str] = None
5656
token_expiry_is_time_of_expiration: bool = False
5757
access_token_name: Union[InterpolatedString, str] = "access_token"
58+
access_token_value: Optional[Union[InterpolatedString, str]] = None
5859
expires_in_name: Union[InterpolatedString, str] = "expires_in"
5960
refresh_request_body: Optional[Mapping[str, Any]] = None
6061
grant_type: Union[InterpolatedString, str] = "refresh_token"
6162
message_repository: MessageRepository = NoopMessageRepository()
6263

6364
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
6465
super().__init__()
65-
self._token_refresh_endpoint = InterpolatedString.create(
66-
self.token_refresh_endpoint, parameters=parameters
67-
)
66+
if self.token_refresh_endpoint is not None:
67+
self._token_refresh_endpoint: Optional[InterpolatedString] = InterpolatedString.create(
68+
self.token_refresh_endpoint, parameters=parameters
69+
)
70+
else:
71+
self._token_refresh_endpoint = None
6872
self._client_id = InterpolatedString.create(self.client_id, parameters=parameters)
6973
self._client_secret = InterpolatedString.create(self.client_secret, parameters=parameters)
7074
if self.refresh_token is not None:
@@ -92,20 +96,31 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
9296
if self.token_expiry_date
9397
else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints
9498
)
95-
self._access_token: Optional[str] = None # access_token is initialized by a setter
99+
if self.access_token_value is not None:
100+
self._access_token_value = InterpolatedString.create(
101+
self.access_token_value, parameters=parameters
102+
).eval(self.config)
103+
else:
104+
self._access_token_value = None
105+
106+
self._access_token: Optional[str] = (
107+
self._access_token_value if self.access_token_value else None
108+
)
96109

97110
if self.get_grant_type() == "refresh_token" and self._refresh_token is None:
98111
raise ValueError(
99112
"OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`"
100113
)
101114

102-
def get_token_refresh_endpoint(self) -> str:
103-
refresh_token: str = self._token_refresh_endpoint.eval(self.config)
104-
if not refresh_token:
105-
raise ValueError(
106-
"OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter"
107-
)
108-
return refresh_token
115+
def get_token_refresh_endpoint(self) -> Optional[str]:
116+
if self._token_refresh_endpoint is not None:
117+
refresh_token_endpoint: str = self._token_refresh_endpoint.eval(self.config)
118+
if not refresh_token_endpoint:
119+
raise ValueError(
120+
"OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter"
121+
)
122+
return refresh_token_endpoint
123+
return None
109124

110125
def get_client_id(self) -> str:
111126
client_id: str = self._client_id.eval(self.config)

airbyte_cdk/sources/declarative/declarative_component_schema.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,6 @@ definitions:
10211021
- type
10221022
- client_id
10231023
- client_secret
1024-
- token_refresh_endpoint
10251024
properties:
10261025
type:
10271026
type: string
@@ -1060,6 +1059,12 @@ definitions:
10601059
default: "access_token"
10611060
examples:
10621061
- access_token
1062+
access_token_value:
1063+
title: Access Token Value
1064+
description: The value of the access_token to bypass the token refreshing using `refresh_token`.
1065+
type: string
1066+
examples:
1067+
- secret_access_token_value
10631068
expires_in_name:
10641069
title: Token Expiry Property Name
10651070
description: The name of the property which contains the expiry date in the response from the token refresh endpoint.

airbyte_cdk/sources/declarative/models/declarative_component_schema.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,8 @@ class OAuthAuthenticator(BaseModel):
489489
],
490490
title="Refresh Token",
491491
)
492-
token_refresh_endpoint: str = Field(
493-
...,
492+
token_refresh_endpoint: Optional[str] = Field(
493+
None,
494494
description="The full URL to call to obtain a new access token.",
495495
examples=["https://connect.squareup.com/oauth2/token"],
496496
title="Token Refresh Endpoint",
@@ -501,6 +501,12 @@ class OAuthAuthenticator(BaseModel):
501501
examples=["access_token"],
502502
title="Access Token Property Name",
503503
)
504+
access_token_value: Optional[str] = Field(
505+
None,
506+
description="The value of the access_token to bypass the token refreshing using `refresh_token`.",
507+
examples=["secret_access_token_value"],
508+
title="Access Token Value",
509+
)
504510
expires_in_name: Optional[str] = Field(
505511
"expires_in",
506512
description="The name of the property which contains the expiry date in the response from the token refresh endpoint.",

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1800,7 +1800,8 @@ def create_oauth_authenticator(
18001800
return DeclarativeSingleUseRefreshTokenOauth2Authenticator( # type: ignore
18011801
config,
18021802
InterpolatedString.create(
1803-
model.token_refresh_endpoint, parameters=model.parameters or {}
1803+
model.token_refresh_endpoint, # type: ignore
1804+
parameters=model.parameters or {},
18041805
).eval(config),
18051806
access_token_name=InterpolatedString.create(
18061807
model.access_token_name or "access_token", parameters=model.parameters or {}
@@ -1834,6 +1835,7 @@ def create_oauth_authenticator(
18341835
# ignore type error because fixing it would have a lot of dependencies, revisit later
18351836
return DeclarativeOauth2Authenticator( # type: ignore
18361837
access_token_name=model.access_token_name or "access_token",
1838+
access_token_value=model.access_token_value,
18371839
client_id=model.client_id,
18381840
client_secret=model.client_secret,
18391841
expires_in_name=model.expires_in_name or "expires_in",

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,16 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques
5454

5555
def get_auth_header(self) -> Mapping[str, Any]:
5656
"""HTTP header to set on the requests"""
57-
return {"Authorization": f"Bearer {self.get_access_token()}"}
57+
token = (
58+
self.access_token
59+
if (
60+
not self.get_token_refresh_endpoint()
61+
or not self.get_refresh_token()
62+
and self.access_token
63+
)
64+
else self.get_access_token()
65+
)
66+
return {"Authorization": f"Bearer {token}"}
5867

5968
def get_access_token(self) -> str:
6069
"""Returns the access token"""
@@ -121,7 +130,7 @@ def _get_refresh_access_token_response(self) -> Any:
121130
try:
122131
response = requests.request(
123132
method="POST",
124-
url=self.get_token_refresh_endpoint(),
133+
url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected.
125134
data=self.build_refresh_request_body(),
126135
)
127136
if response.ok:
@@ -198,7 +207,7 @@ def token_expiry_date_format(self) -> Optional[str]:
198207
return None
199208

200209
@abstractmethod
201-
def get_token_refresh_endpoint(self) -> str:
210+
def get_token_refresh_endpoint(self) -> Optional[str]:
202211
"""Returns the endpoint to refresh the access token"""
203212

204213
@abstractmethod

unit_tests/sources/declarative/auth/test_oauth.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"custom_field": "in_outbound_request",
2727
"another_field": "exists_in_body",
2828
"grant_type": "some_grant_type",
29+
"access_token": "some_access_token",
2930
}
3031
parameters = {"refresh_token": "some_refresh_token"}
3132

@@ -129,6 +130,24 @@ def test_refresh_without_refresh_token(self):
129130
}
130131
assert body == expected
131132

133+
def test_get_auth_header_without_refresh_token_and_without_refresh_token_endpoint(self):
134+
"""
135+
Coverred the case when the `access_token_value` is supplied,
136+
without `token_refresh_endpoint` or `refresh_token` provided.
137+
138+
In this case, it's expected to have the `access_token_value` provided to return the permanent `auth header`,
139+
contains the authentication.
140+
"""
141+
oauth = DeclarativeOauth2Authenticator(
142+
access_token_value="{{ config['access_token'] }}",
143+
client_id="{{ config['client_id'] }}",
144+
client_secret="{{ config['client_secret'] }}",
145+
config=config,
146+
parameters={},
147+
grant_type="client_credentials",
148+
)
149+
assert oauth.get_auth_header() == {"Authorization": "Bearer some_access_token"}
150+
132151
def test_error_on_refresh_token_grant_without_refresh_token(self):
133152
"""
134153
Should throw an error if grant_type refresh_token is configured without refresh_token.

0 commit comments

Comments
 (0)