3
3
#
4
4
5
5
from dataclasses import InitVar , dataclass , field
6
- from typing import Any , List , Mapping , Optional , Union
6
+ from typing import Any , List , Mapping , MutableMapping , Optional , Union
7
7
8
8
import pendulum
9
9
10
10
from airbyte_cdk .sources .declarative .auth .declarative_authenticator import DeclarativeAuthenticator
11
+ from airbyte_cdk .sources .declarative .interpolation .interpolated_boolean import InterpolatedBoolean
11
12
from airbyte_cdk .sources .declarative .interpolation .interpolated_mapping import InterpolatedMapping
12
13
from airbyte_cdk .sources .declarative .interpolation .interpolated_string import InterpolatedString
13
14
from airbyte_cdk .sources .message import MessageRepository , NoopMessageRepository
@@ -44,10 +45,10 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
44
45
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
45
46
"""
46
47
47
- client_id : Union [InterpolatedString , str ]
48
- client_secret : Union [InterpolatedString , str ]
49
48
config : Mapping [str , Any ]
50
49
parameters : InitVar [Mapping [str , Any ]]
50
+ client_id : Optional [Union [InterpolatedString , str ]] = None
51
+ client_secret : Optional [Union [InterpolatedString , str ]] = None
51
52
token_refresh_endpoint : Optional [Union [InterpolatedString , str ]] = None
52
53
refresh_token : Optional [Union [InterpolatedString , str ]] = None
53
54
scopes : Optional [List [str ]] = None
@@ -66,6 +67,8 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
66
67
grant_type_name : Union [InterpolatedString , str ] = "grant_type"
67
68
grant_type : Union [InterpolatedString , str ] = "refresh_token"
68
69
message_repository : MessageRepository = NoopMessageRepository ()
70
+ profile_assertion : Optional [DeclarativeAuthenticator ] = None
71
+ use_profile_assertion : Optional [Union [InterpolatedBoolean , str , bool ]] = False
69
72
70
73
def __post_init__ (self , parameters : Mapping [str , Any ]) -> None :
71
74
super ().__init__ ()
@@ -76,11 +79,19 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
76
79
else :
77
80
self ._token_refresh_endpoint = None
78
81
self ._client_id_name = InterpolatedString .create (self .client_id_name , parameters = parameters )
79
- self ._client_id = InterpolatedString .create (self .client_id , parameters = parameters )
82
+ self ._client_id = (
83
+ InterpolatedString .create (self .client_id , parameters = parameters )
84
+ if self .client_id
85
+ else self .client_id
86
+ )
80
87
self ._client_secret_name = InterpolatedString .create (
81
88
self .client_secret_name , parameters = parameters
82
89
)
83
- self ._client_secret = InterpolatedString .create (self .client_secret , parameters = parameters )
90
+ self ._client_secret = (
91
+ InterpolatedString .create (self .client_secret , parameters = parameters )
92
+ if self .client_secret
93
+ else self .client_secret
94
+ )
84
95
self ._refresh_token_name = InterpolatedString .create (
85
96
self .refresh_token_name , parameters = parameters
86
97
)
@@ -99,7 +110,12 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
99
110
self .grant_type_name = InterpolatedString .create (
100
111
self .grant_type_name , parameters = parameters
101
112
)
102
- self .grant_type = InterpolatedString .create (self .grant_type , parameters = parameters )
113
+ self .grant_type = InterpolatedString .create (
114
+ "urn:ietf:params:oauth:grant-type:jwt-bearer"
115
+ if self .use_profile_assertion
116
+ else self .grant_type ,
117
+ parameters = parameters ,
118
+ )
103
119
self ._refresh_request_body = InterpolatedMapping (
104
120
self .refresh_request_body or {}, parameters = parameters
105
121
)
@@ -115,6 +131,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
115
131
if self .token_expiry_date
116
132
else pendulum .now ().subtract (days = 1 ) # type: ignore # substract does not have type hints
117
133
)
134
+ self .use_profile_assertion = (
135
+ InterpolatedBoolean (self .use_profile_assertion , parameters = parameters )
136
+ if isinstance (self .use_profile_assertion , str )
137
+ else self .use_profile_assertion
138
+ )
139
+ self .assertion_name = "assertion"
140
+
118
141
if self .access_token_value is not None :
119
142
self ._access_token_value = InterpolatedString .create (
120
143
self .access_token_value , parameters = parameters
@@ -126,9 +149,20 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
126
149
self ._access_token_value if self .access_token_value else None
127
150
)
128
151
152
+ if not self .use_profile_assertion and any (
153
+ client_creds is None for client_creds in [self .client_id , self .client_secret ]
154
+ ):
155
+ raise ValueError (
156
+ "OAuthAuthenticator configuration error: Both 'client_id' and 'client_secret' are required for the "
157
+ "basic OAuth flow."
158
+ )
159
+ if self .profile_assertion is None and self .use_profile_assertion :
160
+ raise ValueError (
161
+ "OAuthAuthenticator configuration error: 'profile_assertion' is required when using the profile assertion flow."
162
+ )
129
163
if self .get_grant_type () == "refresh_token" and self ._refresh_token is None :
130
164
raise ValueError (
131
- "OAuthAuthenticator needs a refresh_token parameter if grant_type is set to ` refresh_token` "
165
+ "OAuthAuthenticator configuration error: A ' refresh_token' is required when the ' grant_type' is set to ' refresh_token'. "
132
166
)
133
167
134
168
def get_token_refresh_endpoint (self ) -> Optional [str ]:
@@ -145,19 +179,21 @@ def get_client_id_name(self) -> str:
145
179
return self ._client_id_name .eval (self .config ) # type: ignore # eval returns a string in this context
146
180
147
181
def get_client_id (self ) -> str :
148
- client_id : str = self ._client_id .eval (self .config )
182
+ client_id = self ._client_id .eval (self .config ) if self . _client_id else self . _client_id
149
183
if not client_id :
150
184
raise ValueError ("OAuthAuthenticator was unable to evaluate client_id parameter" )
151
- return client_id
185
+ return client_id # type: ignore # value will be returned as a string, or an error will be raised
152
186
153
187
def get_client_secret_name (self ) -> str :
154
188
return self ._client_secret_name .eval (self .config ) # type: ignore # eval returns a string in this context
155
189
156
190
def get_client_secret (self ) -> str :
157
- client_secret : str = self ._client_secret .eval (self .config )
191
+ client_secret = (
192
+ self ._client_secret .eval (self .config ) if self ._client_secret else self ._client_secret
193
+ )
158
194
if not client_secret :
159
195
raise ValueError ("OAuthAuthenticator was unable to evaluate client_secret parameter" )
160
- return client_secret
196
+ return client_secret # type: ignore # value will be returned as a string, or an error will be raised
161
197
162
198
def get_refresh_token_name (self ) -> str :
163
199
return self ._refresh_token_name .eval (self .config ) # type: ignore # eval returns a string in this context
@@ -192,6 +228,27 @@ def get_token_expiry_date(self) -> pendulum.DateTime:
192
228
def set_token_expiry_date (self , value : Union [str , int ]) -> None :
193
229
self ._token_expiry_date = self ._parse_token_expiration_date (value )
194
230
231
+ def get_assertion_name (self ) -> str :
232
+ return self .assertion_name
233
+
234
+ def get_assertion (self ) -> str :
235
+ if self .profile_assertion is None :
236
+ raise ValueError ("profile_assertion is not set" )
237
+ return self .profile_assertion .token
238
+
239
+ def build_refresh_request_body (self ) -> Mapping [str , Any ]:
240
+ """
241
+ Returns the request body to set on the refresh request
242
+
243
+ Override to define additional parameters
244
+ """
245
+ if self .use_profile_assertion :
246
+ return {
247
+ self .get_grant_type_name (): self .get_grant_type (),
248
+ self .get_assertion_name (): self .get_assertion (),
249
+ }
250
+ return super ().build_refresh_request_body ()
251
+
195
252
@property
196
253
def access_token (self ) -> str :
197
254
if self ._access_token is None :
0 commit comments