Skip to content

Commit 5cc9840

Browse files
lazebnyioctavia-squidington-iii
andauthored
feat(low-code): add profile assertion flow to oauth authenticator component (#236)
Co-authored-by: octavia-squidington-iii <[email protected]>
1 parent 4459243 commit 5cc9840

File tree

9 files changed

+2000176
-47
lines changed

9 files changed

+2000176
-47
lines changed

airbyte_cdk/sources/declarative/auth/oauth.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
#
44

55
from dataclasses import InitVar, dataclass, field
6-
from typing import Any, List, Mapping, Optional, Union
6+
from typing import Any, List, Mapping, MutableMapping, Optional, Union
77

88
import pendulum
99

1010
from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
11+
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
1112
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
1213
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
1314
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
@@ -44,10 +45,10 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
4445
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
4546
"""
4647

47-
client_id: Union[InterpolatedString, str]
48-
client_secret: Union[InterpolatedString, str]
4948
config: Mapping[str, Any]
5049
parameters: InitVar[Mapping[str, Any]]
50+
client_id: Optional[Union[InterpolatedString, str]] = None
51+
client_secret: Optional[Union[InterpolatedString, str]] = None
5152
token_refresh_endpoint: Optional[Union[InterpolatedString, str]] = None
5253
refresh_token: Optional[Union[InterpolatedString, str]] = None
5354
scopes: Optional[List[str]] = None
@@ -66,6 +67,8 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
6667
grant_type_name: Union[InterpolatedString, str] = "grant_type"
6768
grant_type: Union[InterpolatedString, str] = "refresh_token"
6869
message_repository: MessageRepository = NoopMessageRepository()
70+
profile_assertion: Optional[DeclarativeAuthenticator] = None
71+
use_profile_assertion: Optional[Union[InterpolatedBoolean, str, bool]] = False
6972

7073
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
7174
super().__init__()
@@ -76,11 +79,19 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
7679
else:
7780
self._token_refresh_endpoint = None
7881
self._client_id_name = InterpolatedString.create(self.client_id_name, parameters=parameters)
79-
self._client_id = InterpolatedString.create(self.client_id, parameters=parameters)
82+
self._client_id = (
83+
InterpolatedString.create(self.client_id, parameters=parameters)
84+
if self.client_id
85+
else self.client_id
86+
)
8087
self._client_secret_name = InterpolatedString.create(
8188
self.client_secret_name, parameters=parameters
8289
)
83-
self._client_secret = InterpolatedString.create(self.client_secret, parameters=parameters)
90+
self._client_secret = (
91+
InterpolatedString.create(self.client_secret, parameters=parameters)
92+
if self.client_secret
93+
else self.client_secret
94+
)
8495
self._refresh_token_name = InterpolatedString.create(
8596
self.refresh_token_name, parameters=parameters
8697
)
@@ -99,7 +110,12 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
99110
self.grant_type_name = InterpolatedString.create(
100111
self.grant_type_name, parameters=parameters
101112
)
102-
self.grant_type = InterpolatedString.create(self.grant_type, parameters=parameters)
113+
self.grant_type = InterpolatedString.create(
114+
"urn:ietf:params:oauth:grant-type:jwt-bearer"
115+
if self.use_profile_assertion
116+
else self.grant_type,
117+
parameters=parameters,
118+
)
103119
self._refresh_request_body = InterpolatedMapping(
104120
self.refresh_request_body or {}, parameters=parameters
105121
)
@@ -115,6 +131,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
115131
if self.token_expiry_date
116132
else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints
117133
)
134+
self.use_profile_assertion = (
135+
InterpolatedBoolean(self.use_profile_assertion, parameters=parameters)
136+
if isinstance(self.use_profile_assertion, str)
137+
else self.use_profile_assertion
138+
)
139+
self.assertion_name = "assertion"
140+
118141
if self.access_token_value is not None:
119142
self._access_token_value = InterpolatedString.create(
120143
self.access_token_value, parameters=parameters
@@ -126,9 +149,20 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
126149
self._access_token_value if self.access_token_value else None
127150
)
128151

152+
if not self.use_profile_assertion and any(
153+
client_creds is None for client_creds in [self.client_id, self.client_secret]
154+
):
155+
raise ValueError(
156+
"OAuthAuthenticator configuration error: Both 'client_id' and 'client_secret' are required for the "
157+
"basic OAuth flow."
158+
)
159+
if self.profile_assertion is None and self.use_profile_assertion:
160+
raise ValueError(
161+
"OAuthAuthenticator configuration error: 'profile_assertion' is required when using the profile assertion flow."
162+
)
129163
if self.get_grant_type() == "refresh_token" and self._refresh_token is None:
130164
raise ValueError(
131-
"OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`"
165+
"OAuthAuthenticator configuration error: A 'refresh_token' is required when the 'grant_type' is set to 'refresh_token'."
132166
)
133167

134168
def get_token_refresh_endpoint(self) -> Optional[str]:
@@ -145,19 +179,21 @@ def get_client_id_name(self) -> str:
145179
return self._client_id_name.eval(self.config) # type: ignore # eval returns a string in this context
146180

147181
def get_client_id(self) -> str:
148-
client_id: str = self._client_id.eval(self.config)
182+
client_id = self._client_id.eval(self.config) if self._client_id else self._client_id
149183
if not client_id:
150184
raise ValueError("OAuthAuthenticator was unable to evaluate client_id parameter")
151-
return client_id
185+
return client_id # type: ignore # value will be returned as a string, or an error will be raised
152186

153187
def get_client_secret_name(self) -> str:
154188
return self._client_secret_name.eval(self.config) # type: ignore # eval returns a string in this context
155189

156190
def get_client_secret(self) -> str:
157-
client_secret: str = self._client_secret.eval(self.config)
191+
client_secret = (
192+
self._client_secret.eval(self.config) if self._client_secret else self._client_secret
193+
)
158194
if not client_secret:
159195
raise ValueError("OAuthAuthenticator was unable to evaluate client_secret parameter")
160-
return client_secret
196+
return client_secret # type: ignore # value will be returned as a string, or an error will be raised
161197

162198
def get_refresh_token_name(self) -> str:
163199
return self._refresh_token_name.eval(self.config) # type: ignore # eval returns a string in this context
@@ -192,6 +228,27 @@ def get_token_expiry_date(self) -> pendulum.DateTime:
192228
def set_token_expiry_date(self, value: Union[str, int]) -> None:
193229
self._token_expiry_date = self._parse_token_expiration_date(value)
194230

231+
def get_assertion_name(self) -> str:
232+
return self.assertion_name
233+
234+
def get_assertion(self) -> str:
235+
if self.profile_assertion is None:
236+
raise ValueError("profile_assertion is not set")
237+
return self.profile_assertion.token
238+
239+
def build_refresh_request_body(self) -> Mapping[str, Any]:
240+
"""
241+
Returns the request body to set on the refresh request
242+
243+
Override to define additional parameters
244+
"""
245+
if self.use_profile_assertion:
246+
return {
247+
self.get_grant_type_name(): self.get_grant_type(),
248+
self.get_assertion_name(): self.get_assertion(),
249+
}
250+
return super().build_refresh_request_body()
251+
195252
@property
196253
def access_token(self) -> str:
197254
if self._access_token is None:

airbyte_cdk/sources/declarative/declarative_component_schema.yaml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,8 +1081,6 @@ definitions:
10811081
type: object
10821082
required:
10831083
- type
1084-
- client_id
1085-
- client_secret
10861084
properties:
10871085
type:
10881086
type: string
@@ -1277,6 +1275,15 @@ definitions:
12771275
default: []
12781276
examples:
12791277
- ["invalid_grant", "invalid_permissions"]
1278+
profile_assertion:
1279+
title: Profile Assertion
1280+
description: The authenticator being used to authenticate the client authenticator.
1281+
"$ref": "#/definitions/JwtAuthenticator"
1282+
use_profile_assertion:
1283+
title: Use Profile Assertion
1284+
description: Enable using profile assertion as a flow for OAuth authorization.
1285+
type: boolean
1286+
default: false
12801287
$parameters:
12811288
type: object
12821289
additionalProperties: true

airbyte_cdk/sources/declarative/models/declarative_component_schema.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,8 @@ class OAuthAuthenticator(BaseModel):
506506
examples=["custom_app_id"],
507507
title="Client ID Property Name",
508508
)
509-
client_id: str = Field(
510-
...,
509+
client_id: Optional[str] = Field(
510+
None,
511511
description="The OAuth client ID. Fill it in the user inputs.",
512512
examples=["{{ config['client_id }}", "{{ config['credentials']['client_id }}"],
513513
title="Client ID",
@@ -518,8 +518,8 @@ class OAuthAuthenticator(BaseModel):
518518
examples=["custom_app_secret"],
519519
title="Client Secret Property Name",
520520
)
521-
client_secret: str = Field(
522-
...,
521+
client_secret: Optional[str] = Field(
522+
None,
523523
description="The OAuth client secret. Fill it in the user inputs.",
524524
examples=[
525525
"{{ config['client_secret }}",
@@ -624,6 +624,16 @@ class OAuthAuthenticator(BaseModel):
624624
description="When the token updater is defined, new refresh tokens, access tokens and the access token expiry date are written back from the authentication response to the config object. This is important if the refresh token can only used once.",
625625
title="Token Updater",
626626
)
627+
profile_assertion: Optional[JwtAuthenticator] = Field(
628+
None,
629+
description="The authenticator being used to authenticate the client authenticator.",
630+
title="Profile Assertion",
631+
)
632+
use_profile_assertion: Optional[bool] = Field(
633+
False,
634+
description="Enable using profile assertion as a flow for OAuth authorization.",
635+
title="Use Profile Assertion",
636+
)
627637
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
628638

629639

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2100,6 +2100,12 @@ def create_no_pagination(
21002100
def create_oauth_authenticator(
21012101
self, model: OAuthAuthenticatorModel, config: Config, **kwargs: Any
21022102
) -> DeclarativeOauth2Authenticator:
2103+
profile_assertion = (
2104+
self._create_component_from_model(model.profile_assertion, config=config)
2105+
if model.profile_assertion
2106+
else None
2107+
)
2108+
21032109
if model.refresh_token_updater:
21042110
# ignore type error because fixing it would have a lot of dependencies, revisit later
21052111
return DeclarativeSingleUseRefreshTokenOauth2Authenticator( # type: ignore
@@ -2120,13 +2126,17 @@ def create_oauth_authenticator(
21202126
).eval(config),
21212127
client_id=InterpolatedString.create(
21222128
model.client_id, parameters=model.parameters or {}
2123-
).eval(config),
2129+
).eval(config)
2130+
if model.client_id
2131+
else model.client_id,
21242132
client_secret_name=InterpolatedString.create(
21252133
model.client_secret_name or "client_secret", parameters=model.parameters or {}
21262134
).eval(config),
21272135
client_secret=InterpolatedString.create(
21282136
model.client_secret, parameters=model.parameters or {}
2129-
).eval(config),
2137+
).eval(config)
2138+
if model.client_secret
2139+
else model.client_secret,
21302140
access_token_config_path=model.refresh_token_updater.access_token_config_path,
21312141
refresh_token_config_path=model.refresh_token_updater.refresh_token_config_path,
21322142
token_expiry_date_config_path=model.refresh_token_updater.token_expiry_date_config_path,
@@ -2172,6 +2182,8 @@ def create_oauth_authenticator(
21722182
config=config,
21732183
parameters=model.parameters or {},
21742184
message_repository=self._message_repository,
2185+
profile_assertion=profile_assertion,
2186+
use_profile_assertion=model.use_profile_assertion,
21752187
)
21762188

21772189
def create_offset_increment(

airbyte_cdk/sources/http_logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def format_http_message(
4545
log_message["http"]["is_auxiliary"] = is_auxiliary # type: ignore [index]
4646
if stream_name:
4747
log_message["airbyte_cdk"] = {"stream": {"name": stream_name}}
48-
return log_message # type: ignore [return-value] # got "dict[str, object]", expected "dict[str, JsonType]"
48+
return log_message # type: ignore[return-value] # got "dict[str, object]", expected "dict[str, JsonType]"
4949

5050

5151
def _normalize_body_string(body_str: Optional[Union[str, bytes]]) -> Optional[str]:

0 commit comments

Comments
 (0)