Skip to content

Commit 62ad2ca

Browse files
authored
Replace Deprecated (Current) OAuth2 Handling with AuthManager Implementation LegacyOAuth2AuthManager (#1981)
<!-- Thanks for opening a pull request! --> <!-- In the case this PR will resolve an issue, please replace ${GITHUB_ISSUE_ID} below with the actual GitHub issue id. --> <!-- Closes #1909 --> # Rationale for this change Replace existing Auth handling with `LegacyOAuth2AuthManager`. Tracking issue: #1909 There will be follow-up PRs to this PR that will address the following: - introduce a mechanism for using a custom `AuthManager` implementation, along with the ability to use a set of config parameters - introduce an `OAuth2AuthManager` that more closely follows the OAuth2 protocol, and also uses a separate thread to proactively refresh the token, rather than reactively refreshing the token on `UnauthorizedError` or the deprecated `AuthorizationExpiredError`. # Are these changes tested? Yes, both through unit and integration tests # Are there any user-facing changes? Yes - previously, if `TOKEN` and `CREDENTIAL` were both defined, the `oauth/tokens` endpoint wouldn't be used to refresh the token with client credentials when the `RestCatalog` was initialized. However, the `oauth/tokens` endpoint would be used on retries that handled a 401 or 419 error. This erratic behavior will now be updated as follows: - if `CREDENTIAL` is defined, the `oauth/tokens` endpoint will be used to fetch the access token using the client credentials both when the `RestCatalog` is initialized, and when the `_refresh_token` call is made as a reaction to a 401 or 419 error. - if both `CREDENTIAL` and `TOKEN` are defined, we will follow the above behavior. - if only `TOKEN` is defined, the initial token will be used instead <!-- In the case of user-facing changes, please add the changelog label. -->
1 parent ff7bc62 commit 62ad2ca

File tree

4 files changed

+275
-144
lines changed

4 files changed

+275
-144
lines changed

pyiceberg/catalog/rest/__init__.py

Lines changed: 49 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,18 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
from enum import Enum
18-
from json import JSONDecodeError
1918
from typing import (
2019
TYPE_CHECKING,
2120
Any,
2221
Dict,
2322
List,
24-
Literal,
2523
Optional,
2624
Set,
2725
Tuple,
28-
Type,
2926
Union,
3027
)
3128

32-
from pydantic import Field, ValidationError, field_validator
29+
from pydantic import Field, field_validator
3330
from requests import HTTPError, Session
3431
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
3532

@@ -41,22 +38,18 @@
4138
Catalog,
4239
PropertiesUpdateSummary,
4340
)
41+
from pyiceberg.catalog.rest.auth import AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
42+
from pyiceberg.catalog.rest.response import _handle_non_200_response
4443
from pyiceberg.exceptions import (
4544
AuthorizationExpiredError,
46-
BadRequestError,
4745
CommitFailedException,
4846
CommitStateUnknownException,
49-
ForbiddenError,
5047
NamespaceAlreadyExistsError,
5148
NamespaceNotEmptyError,
5249
NoSuchIdentifierError,
5350
NoSuchNamespaceError,
5451
NoSuchTableError,
5552
NoSuchViewError,
56-
OAuthError,
57-
RESTError,
58-
ServerError,
59-
ServiceUnavailableError,
6053
TableAlreadyExistsError,
6154
UnauthorizedError,
6255
)
@@ -182,15 +175,6 @@ class RegisterTableRequest(IcebergBaseModel):
182175
metadata_location: str = Field(..., alias="metadata-location")
183176

184177

185-
class TokenResponse(IcebergBaseModel):
186-
access_token: str = Field()
187-
token_type: str = Field()
188-
expires_in: Optional[int] = Field(default=None)
189-
issued_token_type: Optional[str] = Field(default=None)
190-
refresh_token: Optional[str] = Field(default=None)
191-
scope: Optional[str] = Field(default=None)
192-
193-
194178
class ConfigResponse(IcebergBaseModel):
195179
defaults: Properties = Field()
196180
overrides: Properties = Field()
@@ -229,24 +213,6 @@ class ListViewsResponse(IcebergBaseModel):
229213
identifiers: List[ListViewResponseEntry] = Field()
230214

231215

232-
class ErrorResponseMessage(IcebergBaseModel):
233-
message: str = Field()
234-
type: str = Field()
235-
code: int = Field()
236-
237-
238-
class ErrorResponse(IcebergBaseModel):
239-
error: ErrorResponseMessage = Field()
240-
241-
242-
class OAuthErrorResponse(IcebergBaseModel):
243-
error: Literal[
244-
"invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope"
245-
]
246-
error_description: Optional[str] = None
247-
error_uri: Optional[str] = None
248-
249-
250216
class RestCatalog(Catalog):
251217
uri: str
252218
_session: Session
@@ -279,8 +245,7 @@ def _create_session(self) -> Session:
279245
elif ssl_client_cert := ssl_client.get(CERT):
280246
session.cert = ssl_client_cert
281247

282-
self._refresh_token(session, self.properties.get(TOKEN))
283-
248+
session.auth = AuthManagerAdapter(self._create_legacy_oauth2_auth_manager(session))
284249
# Set HTTP headers
285250
self._config_headers(session)
286251

@@ -290,6 +255,26 @@ def _create_session(self) -> Session:
290255

291256
return session
292257

258+
def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager:
259+
"""Create the LegacyOAuth2AuthManager by fetching required properties.
260+
261+
This will be removed in PyIceberg 1.0
262+
"""
263+
client_credentials = self.properties.get(CREDENTIAL)
264+
# We want to call `self.auth_url` only when we are using CREDENTIAL
265+
# with the legacy OAUTH2 flow as it will raise a DeprecationWarning
266+
auth_url = self.auth_url if client_credentials is not None else None
267+
268+
auth_config = {
269+
"session": session,
270+
"auth_url": auth_url,
271+
"credential": client_credentials,
272+
"initial_token": self.properties.get(TOKEN),
273+
"optional_oauth_params": self._extract_optional_oauth_params(),
274+
}
275+
276+
return AuthManagerFactory.create("legacyoauth2", auth_config)
277+
293278
def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier:
294279
"""Check if the identifier has at least one element."""
295280
identifier_tuple = Catalog.identifier_to_tuple(identifier)
@@ -352,27 +337,6 @@ def _extract_optional_oauth_params(self) -> Dict[str, str]:
352337

353338
return optional_oauth_param
354339

355-
def _fetch_access_token(self, session: Session, credential: str) -> str:
356-
if SEMICOLON in credential:
357-
client_id, client_secret = credential.split(SEMICOLON)
358-
else:
359-
client_id, client_secret = None, credential
360-
361-
data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret}
362-
363-
optional_oauth_params = self._extract_optional_oauth_params()
364-
data.update(optional_oauth_params)
365-
366-
response = session.post(
367-
url=self.auth_url, data=data, headers={**session.headers, "Content-type": "application/x-www-form-urlencoded"}
368-
)
369-
try:
370-
response.raise_for_status()
371-
except HTTPError as exc:
372-
self._handle_non_200_response(exc, {400: OAuthError, 401: OAuthError})
373-
374-
return TokenResponse.model_validate_json(response.text).access_token
375-
376340
def _fetch_config(self) -> None:
377341
params = {}
378342
if warehouse_location := self.properties.get(WAREHOUSE_LOCATION):
@@ -383,7 +347,7 @@ def _fetch_config(self) -> None:
383347
try:
384348
response.raise_for_status()
385349
except HTTPError as exc:
386-
self._handle_non_200_response(exc, {})
350+
_handle_non_200_response(exc, {})
387351
config_response = ConfigResponse.model_validate_json(response.text)
388352

389353
config = config_response.defaults
@@ -413,58 +377,6 @@ def _split_identifier_for_json(self, identifier: Union[str, Identifier]) -> Dict
413377
identifier_tuple = self._identifier_to_validated_tuple(identifier)
414378
return {"namespace": identifier_tuple[:-1], "name": identifier_tuple[-1]}
415379

416-
def _handle_non_200_response(self, exc: HTTPError, error_handler: Dict[int, Type[Exception]]) -> None:
417-
exception: Type[Exception]
418-
419-
if exc.response is None:
420-
raise ValueError("Did not receive a response")
421-
422-
code = exc.response.status_code
423-
if code in error_handler:
424-
exception = error_handler[code]
425-
elif code == 400:
426-
exception = BadRequestError
427-
elif code == 401:
428-
exception = UnauthorizedError
429-
elif code == 403:
430-
exception = ForbiddenError
431-
elif code == 422:
432-
exception = RESTError
433-
elif code == 419:
434-
exception = AuthorizationExpiredError
435-
elif code == 501:
436-
exception = NotImplementedError
437-
elif code == 503:
438-
exception = ServiceUnavailableError
439-
elif 500 <= code < 600:
440-
exception = ServerError
441-
else:
442-
exception = RESTError
443-
444-
try:
445-
if exception == OAuthError:
446-
# The OAuthErrorResponse has a different format
447-
error = OAuthErrorResponse.model_validate_json(exc.response.text)
448-
response = str(error.error)
449-
if description := error.error_description:
450-
response += f": {description}"
451-
if uri := error.error_uri:
452-
response += f" ({uri})"
453-
else:
454-
error = ErrorResponse.model_validate_json(exc.response.text).error
455-
response = f"{error.type}: {error.message}"
456-
except JSONDecodeError:
457-
# In the case we don't have a proper response
458-
response = f"RESTError {exc.response.status_code}: Could not decode json payload: {exc.response.text}"
459-
except ValidationError as e:
460-
# In the case we don't have a proper response
461-
errs = ", ".join(err["msg"] for err in e.errors())
462-
response = (
463-
f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}"
464-
)
465-
466-
raise exception(response) from exc
467-
468380
def _init_sigv4(self, session: Session) -> None:
469381
from urllib import parse
470382

@@ -534,16 +446,13 @@ def _response_to_staged_table(self, identifier_tuple: Tuple[str, ...], table_res
534446
catalog=self,
535447
)
536448

537-
def _refresh_token(self, session: Optional[Session] = None, initial_token: Optional[str] = None) -> None:
538-
session = session or self._session
539-
if initial_token is not None:
540-
self.properties[TOKEN] = initial_token
541-
elif CREDENTIAL in self.properties:
542-
self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])
543-
544-
# Set Auth token for subsequent calls in the session
545-
if token := self.properties.get(TOKEN):
546-
session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
449+
def _refresh_token(self) -> None:
450+
# Reactive token refresh is atypical - we should proactively refresh tokens in a separate thread
451+
# instead of retrying on Auth Exceptions. Keeping refresh behavior for the LegacyOAuth2AuthManager
452+
# for backward compatibility
453+
auth_manager = self._session.auth.auth_manager # type: ignore[union-attr]
454+
if isinstance(auth_manager, LegacyOAuth2AuthManager):
455+
auth_manager._refresh_token()
547456

548457
def _config_headers(self, session: Session) -> None:
549458
header_properties = get_header_properties(self.properties)
@@ -588,7 +497,7 @@ def _create_table(
588497
try:
589498
response.raise_for_status()
590499
except HTTPError as exc:
591-
self._handle_non_200_response(exc, {409: TableAlreadyExistsError})
500+
_handle_non_200_response(exc, {409: TableAlreadyExistsError})
592501
return TableResponse.model_validate_json(response.text)
593502

594503
@retry(**_RETRY_ARGS)
@@ -661,7 +570,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
661570
try:
662571
response.raise_for_status()
663572
except HTTPError as exc:
664-
self._handle_non_200_response(exc, {409: TableAlreadyExistsError})
573+
_handle_non_200_response(exc, {409: TableAlreadyExistsError})
665574

666575
table_response = TableResponse.model_validate_json(response.text)
667576
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)
@@ -674,7 +583,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
674583
try:
675584
response.raise_for_status()
676585
except HTTPError as exc:
677-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
586+
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
678587
return [(*table.namespace, table.name) for table in ListTablesResponse.model_validate_json(response.text).identifiers]
679588

680589
@retry(**_RETRY_ARGS)
@@ -692,7 +601,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
692601
try:
693602
response.raise_for_status()
694603
except HTTPError as exc:
695-
self._handle_non_200_response(exc, {404: NoSuchTableError})
604+
_handle_non_200_response(exc, {404: NoSuchTableError})
696605

697606
table_response = TableResponse.model_validate_json(response.text)
698607
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)
@@ -705,7 +614,7 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool =
705614
try:
706615
response.raise_for_status()
707616
except HTTPError as exc:
708-
self._handle_non_200_response(exc, {404: NoSuchTableError})
617+
_handle_non_200_response(exc, {404: NoSuchTableError})
709618

710619
@retry(**_RETRY_ARGS)
711620
def purge_table(self, identifier: Union[str, Identifier]) -> None:
@@ -721,7 +630,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
721630
try:
722631
response.raise_for_status()
723632
except HTTPError as exc:
724-
self._handle_non_200_response(exc, {404: NoSuchTableError, 409: TableAlreadyExistsError})
633+
_handle_non_200_response(exc, {404: NoSuchTableError, 409: TableAlreadyExistsError})
725634

726635
return self.load_table(to_identifier)
727636

@@ -744,7 +653,7 @@ def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]:
744653
try:
745654
response.raise_for_status()
746655
except HTTPError as exc:
747-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
656+
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
748657
return [(*view.namespace, view.name) for view in ListViewsResponse.model_validate_json(response.text).identifiers]
749658

750659
@retry(**_RETRY_ARGS)
@@ -782,7 +691,7 @@ def commit_table(
782691
try:
783692
response.raise_for_status()
784693
except HTTPError as exc:
785-
self._handle_non_200_response(
694+
_handle_non_200_response(
786695
exc,
787696
{
788697
409: CommitFailedException,
@@ -801,7 +710,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper
801710
try:
802711
response.raise_for_status()
803712
except HTTPError as exc:
804-
self._handle_non_200_response(exc, {409: NamespaceAlreadyExistsError})
713+
_handle_non_200_response(exc, {409: NamespaceAlreadyExistsError})
805714

806715
@retry(**_RETRY_ARGS)
807716
def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
@@ -811,7 +720,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
811720
try:
812721
response.raise_for_status()
813722
except HTTPError as exc:
814-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError})
723+
_handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError})
815724

816725
@retry(**_RETRY_ARGS)
817726
def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
@@ -826,7 +735,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi
826735
try:
827736
response.raise_for_status()
828737
except HTTPError as exc:
829-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
738+
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
830739

831740
return ListNamespaceResponse.model_validate_json(response.text).namespaces
832741

@@ -838,7 +747,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper
838747
try:
839748
response.raise_for_status()
840749
except HTTPError as exc:
841-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
750+
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
842751

843752
return NamespaceResponse.model_validate_json(response.text).properties
844753

@@ -853,7 +762,7 @@ def update_namespace_properties(
853762
try:
854763
response.raise_for_status()
855764
except HTTPError as exc:
856-
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
765+
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
857766
parsed_response = UpdateNamespacePropertiesResponse.model_validate_json(response.text)
858767
return PropertiesUpdateSummary(
859768
removed=parsed_response.removed,
@@ -875,7 +784,7 @@ def namespace_exists(self, namespace: Union[str, Identifier]) -> bool:
875784
try:
876785
response.raise_for_status()
877786
except HTTPError as exc:
878-
self._handle_non_200_response(exc, {})
787+
_handle_non_200_response(exc, {})
879788

880789
return False
881790

@@ -901,7 +810,7 @@ def table_exists(self, identifier: Union[str, Identifier]) -> bool:
901810
try:
902811
response.raise_for_status()
903812
except HTTPError as exc:
904-
self._handle_non_200_response(exc, {})
813+
_handle_non_200_response(exc, {})
905814

906815
return False
907816

@@ -926,7 +835,7 @@ def view_exists(self, identifier: Union[str, Identifier]) -> bool:
926835
try:
927836
response.raise_for_status()
928837
except HTTPError as exc:
929-
self._handle_non_200_response(exc, {})
838+
_handle_non_200_response(exc, {})
930839

931840
return False
932841

@@ -938,4 +847,4 @@ def drop_view(self, identifier: Union[str]) -> None:
938847
try:
939848
response.raise_for_status()
940849
except HTTPError as exc:
941-
self._handle_non_200_response(exc, {404: NoSuchViewError})
850+
_handle_non_200_response(exc, {404: NoSuchViewError})

0 commit comments

Comments
 (0)