Skip to content

Add Support for Custom AuthManager implementation #2055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion pyiceberg/catalog/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class IdentifierKind(Enum):
SIGV4_SERVICE = "rest.signing-name"
OAUTH2_SERVER_URI = "oauth2-server-uri"
SNAPSHOT_LOADING_MODE = "snapshot-loading-mode"
AUTH = "auth"

NAMESPACE_SEPARATOR = b"\x1f".decode(UTF8)

Expand Down Expand Up @@ -247,7 +248,19 @@ def _create_session(self) -> Session:
elif ssl_client_cert := ssl_client.get(CERT):
session.cert = ssl_client_cert

session.auth = AuthManagerAdapter(self._create_legacy_oauth2_auth_manager(session))
if auth_config := self.properties.get(AUTH):
# set up auth_manager based on the properties
auth_type = auth_config.get("type")
if auth_type is None:
raise ValueError("auth.type must be defined")
auth_type_config = auth_config.get(auth_type, {})
if auth_impl := auth_config.get("impl"):
session.auth = AuthManagerAdapter(AuthManagerFactory.create(auth_impl, auth_type_config))
else:
session.auth = AuthManagerAdapter(AuthManagerFactory.create(auth_type, auth_type_config))
else:
session.auth = AuthManagerAdapter(self._create_legacy_oauth2_auth_manager(session))

# Set HTTP headers
self._config_headers(session)

Expand Down
87 changes: 87 additions & 0 deletions pyiceberg/catalog/rest/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@

import base64
import importlib
import threading
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type

import requests
from requests import HTTPError, PreparedRequest, Session
from requests.auth import AuthBase

Expand All @@ -42,11 +45,15 @@ def auth_header(self) -> Optional[str]:


class NoopAuthManager(AuthManager):
    """AuthManager for the no-authentication case: it never produces a header value."""

    def auth_header(self) -> Optional[str]:
        """Return ``None`` unconditionally, since there are no credentials to present."""
        return None


class BasicAuthManager(AuthManager):
"""AuthManager implementation that supports basic password auth."""

def __init__(self, username: str, password: str):
credentials = f"{username}:{password}"
self._token = base64.b64encode(credentials.encode()).decode()
Expand All @@ -56,6 +63,12 @@ def auth_header(self) -> str:


class LegacyOAuth2AuthManager(AuthManager):
"""Legacy OAuth2 AuthManager implementation.

This class exists for backward compatibility, and will be removed in
PyIceberg 1.0.0 in favor of OAuth2AuthManager.
"""

_session: Session
_auth_url: Optional[str]
_token: Optional[str]
Expand Down Expand Up @@ -109,6 +122,80 @@ def auth_header(self) -> str:
return f"Bearer {self._token}"


class OAuth2TokenProvider:
    """Thread-safe OAuth2 token provider with token refresh support.

    A token is requested from ``token_url`` using the ``client_credentials``
    grant and cached. ``get_token`` requests a new one when no token is cached
    or when the cached token is within ``refresh_margin`` seconds of expiring.

    Args:
        client_id: OAuth2 client identifier sent in the token request.
        client_secret: OAuth2 client secret sent in the token request.
        token_url: URL the token request is POSTed to.
        scope: Optional scope sent in the token request when truthy.
        refresh_margin: Seconds subtracted from the token lifetime so the
            token is refreshed before it actually expires.
        expires_in: Fallback token lifetime in seconds, used when the token
            response does not carry a usable ``expires_in`` value.
    """

    client_id: str
    client_secret: str
    token_url: str
    scope: Optional[str]
    refresh_margin: int
    expires_in: Optional[int]

    _token: Optional[str]
    # Wall-clock timestamp (``time.time()`` based) after which a refresh is due.
    _expires_at: float
    _lock: threading.Lock

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_url: str,
        scope: Optional[str] = None,
        refresh_margin: int = 60,
        expires_in: Optional[int] = None,
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scope = scope
        self.refresh_margin = refresh_margin
        self.expires_in = expires_in

        self._token = None
        # 0 guarantees the first ``get_token`` call triggers a refresh.
        self._expires_at = 0
        self._lock = threading.Lock()

    def _refresh_token(self) -> None:
        """Request a new token from ``token_url`` and cache it.

        Called by ``get_token`` while ``self._lock`` is held.

        Raises:
            requests.HTTPError: If the token endpoint answers with an error status.
            KeyError: If the response has no ``access_token`` field.
            ValueError: If no token lifetime is available (see ``_store_token_response``).
        """
        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }
        if self.scope:
            data["scope"] = self.scope

        # NOTE(review): no timeout is passed, so this call blocks (with the lock
        # held by ``get_token``) until the server responds.
        response = requests.post(self.token_url, data=data)
        response.raise_for_status()
        self._store_token_response(response.json())

    def _store_token_response(self, result: Dict[str, Any]) -> None:
        """Validate a decoded token response and cache the token and its expiry.

        The cached state is only modified once the response has been fully
        validated, so a bad response never leaves a token paired with a stale
        expiry timestamp.

        Raises:
            KeyError: If ``result`` has no ``access_token`` field.
            ValueError: If neither the response nor the client configuration
                provides a token lifetime.
        """
        token = result["access_token"]

        # An explicit JSON ``null`` must fall back to the client setting too;
        # ``dict.get(key, default)`` only uses the default when the key is absent.
        expires_in = result.get("expires_in")
        if expires_in is None:
            expires_in = self.expires_in
        if expires_in is None:
            raise ValueError(
                "The expiration time of the Token must be provided by the Server in the Access Token Response in `expires_in` field, or by the PyIceberg Client."
            )

        self._token = token
        self._expires_at = time.time() + expires_in - self.refresh_margin

    def get_token(self) -> str:
        """Return the cached token, refreshing it first when missing or due for refresh.

        Raises:
            ValueError: If no token is available after a refresh.
        """
        with self._lock:
            if not self._token or time.time() >= self._expires_at:
                self._refresh_token()
            if self._token is None:
                raise ValueError("Authorization token is None after refresh")
            return self._token


class OAuth2AuthManager(AuthManager):
    """AuthManager that supplies OAuth2 (IETF RFC6749) bearer tokens.

    The token itself comes from the wrapped ``OAuth2TokenProvider``; this class
    only formats it as a ``Bearer`` header value.
    """

    def __init__(self, token_provider: OAuth2TokenProvider):
        self.token_provider = token_provider

    def auth_header(self) -> str:
        """Return ``Bearer <token>`` using whatever the provider currently hands out."""
        current_token = self.token_provider.get_token()
        return "Bearer {}".format(current_token)


class AuthManagerAdapter(AuthBase):
"""A `requests.auth.AuthBase` adapter that integrates an `AuthManager` into a `requests.Session` to automatically attach the appropriate Authorization header to every request.

Expand Down
36 changes: 36 additions & 0 deletions tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,42 @@ def test_request_session_with_ssl_client_cert() -> None:
assert "Could not find the TLS certificate file, invalid path: path_to_client_cert" in str(e.value)


def test_rest_catalog_with_basic_auth_type() -> None:
    # Given: a "basic" auth configuration that omits the required password
    catalog_properties = {
        "uri": TEST_URI,
        "auth": {
            "type": "basic",
            "basic": {
                "username": "one",
            },
        },
    }
    # When / Then: building the catalog fails because BasicAuthManager cannot be constructed
    with pytest.raises(TypeError) as e:
        RestCatalog("rest", **catalog_properties)  # type: ignore
    # Only Python 3.10+ prefixes the message with the class name
    # ("BasicAuthManager.__init__()"); older interpreters report a bare
    # "__init__()", so assert on the version-independent part only.
    assert "__init__() missing 1 required positional argument: 'password'" in str(e.value)


def test_rest_catalog_with_auth_impl() -> None:
    # Given: an auth configuration whose "impl" names a dummy, nonexistent package
    custom_auth = {
        "type": "custom",
        "impl": "dummy.nonexistent.package",
        "custom": {
            "property1": "one",
            "property2": "two",
        },
    }
    props = {"uri": TEST_URI, "auth": custom_auth}

    # When / Then: building the catalog reports that the AuthManager class could not be loaded
    with pytest.raises(ValueError) as exc_info:
        RestCatalog("rest", **props)  # type: ignore
    message = str(exc_info.value)
    assert "Could not load AuthManager class for 'dummy.nonexistent.package'" in message


EXAMPLE_ENV = {"PYICEBERG_CATALOG__PRODUCTION__URI": TEST_URI}


Expand Down
Loading