Implements Token Federation for Python Driver #552
base: main
@@ -0,0 +1,78 @@
name: Token Federation Test

# Tests token federation functionality with GitHub Actions OIDC tokens
on:
  # Manual trigger with required inputs
  workflow_dispatch:
    inputs:
      databricks_host:
        description: 'Databricks host URL (e.g., example.cloud.databricks.com)'
        required: true
      databricks_http_path:
        description: 'Databricks HTTP path (e.g., /sql/1.0/warehouses/abc123)'
        required: true
      identity_federation_client_id:
        description: 'Identity federation client ID'
        required: true

  # Run on PRs that might affect token federation
  pull_request:
    branches: [main]
    paths:
      - 'src/databricks/sql/auth/**'
      - 'examples/token_federation_*.py'
      - 'tests/token_federation/**'
      - '.github/workflows/token-federation-test.yml'

  # Run on push to main that affects token federation
  push:
    branches: [main]
    paths:
      - 'src/databricks/sql/auth/**'
      - 'examples/token_federation_*.py'
      - 'tests/token_federation/**'
      - '.github/workflows/token-federation-test.yml'

permissions:
  id-token: write # Required for GitHub OIDC token
  contents: read

jobs:
  test-token-federation:
    name: Test Token Federation
    runs-on:
      group: databricks-protected-runner-group
      labels: linux-ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python 3.9
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
          cache: 'pip'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .
          pip install pyarrow

      - name: Get GitHub OIDC token
        id: get-id-token
        uses: actions/github-script@v7
        with:
          script: |
            const token = await core.getIDToken('https://github.com/databricks')
            core.setSecret(token)
            core.setOutput('token', token)

      - name: Test token federation with GitHub OIDC token
        env:
          DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }}
          DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }}
          IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }}
          OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
        run: python tests/token_federation/github_oidc_test.py
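
When debugging a workflow like this one, it can help to inspect the claims carried by the OIDC token before it is exchanged. The following is a minimal sketch that decodes a JWT payload with the standard library only and does not verify the signature. The helper names `decode_jwt_claims` and `_b64url` and the sample token are illustrative assumptions, not part of this PR; the claims shown only mirror the audience requested in the workflow above.

```python
import base64
import json


def decode_jwt_claims(token: str) -> dict:
    """Return the payload claims of a JWT WITHOUT verifying its signature."""
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("expected a JWT with three dot-separated segments")
    payload = parts[1]
    # JWT segments are base64url-encoded with padding stripped; restore it.
    payload += "=" * (-len(payload) % 4)
    return json.loads(base64.urlsafe_b64decode(payload))


def _b64url(data: dict) -> str:
    """Encode a dict as an unpadded base64url JSON segment (demo helper only)."""
    raw = json.dumps(data).encode("utf-8")
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


# Hypothetical sample token; in the workflow the real one is in OIDC_TOKEN.
sample = ".".join(
    [
        _b64url({"alg": "none"}),
        _b64url(
            {
                "iss": "https://token.actions.githubusercontent.com",
                "aud": "https://github.com/databricks",
            }
        ),
        "signature",
    ]
)
claims = decode_jwt_claims(sample)
print(claims["aud"])
```

Unverified claims are suitable for logging and troubleshooting only; the Databricks token endpoint remains the party that validates the token.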
@@ -5,13 +5,18 @@
    AuthProvider,
    AccessTokenAuthProvider,
    ExternalAuthProvider,
    CredentialsProvider,
    DatabricksOAuthProvider,
)


class AuthType(Enum):
    DATABRICKS_OAUTH = "databricks-oauth"
    AZURE_OAUTH = "azure-oauth"
    # TODO: Token federation should be a feature that works with different auth types,
    # not an auth type itself. This will be refactored in a future change.
    # We will add a use_token_federation flag that can be used with any auth type.
    TOKEN_FEDERATION = "token-federation"
    # other supported types (access_token) can be inferred
    # we can add more types as needed later

@@ -29,6 +34,7 @@ def __init__(
        tls_client_cert_file: Optional[str] = None,
        oauth_persistence=None,
        credentials_provider=None,
        identity_federation_client_id: Optional[str] = None,
Review comment: Is this for the workload identity federation flow?
    ):
        self.hostname = hostname
        self.access_token = access_token

@@ -40,11 +46,64 @@ def __init__(
        self.tls_client_cert_file = tls_client_cert_file
        self.oauth_persistence = oauth_persistence
        self.credentials_provider = credentials_provider
        self.identity_federation_client_id = identity_federation_client_id


def get_auth_provider(cfg: ClientContext):
    """
    Get an appropriate auth provider based on the provided configuration.

    Token Federation Support:
    -----------------------
    Currently, token federation is implemented as a separate auth type, but the goal is to
    refactor it as a feature that can work with any auth type. The current implementation
Review comment: Does this mean that both will be supported in the future, a separate 'use_token_federation' flag as well as a separate auth type? Deprecating anything that has been released in a client is not easy.
    is maintained for backward compatibility while the refactoring is planned.

    Future refactoring will introduce a `use_token_federation` flag that can be combined
    with any auth type to enable token federation.

    Args:
        cfg: The client context containing configuration parameters

    Returns:
        An appropriate AuthProvider instance

    Raises:
        RuntimeError: If no valid authentication settings are provided
    """
    # If credentials_provider is explicitly provided
    if cfg.credentials_provider:
        # If token federation is enabled and credentials provider is provided,
        # wrap the credentials provider with DatabricksTokenFederationProvider
        if cfg.auth_type == AuthType.TOKEN_FEDERATION.value:
            from databricks.sql.auth.token_federation import (
                DatabricksTokenFederationProvider,
            )

            federation_provider = DatabricksTokenFederationProvider(
                cfg.credentials_provider,
                cfg.hostname,
                cfg.identity_federation_client_id,
            )
            return ExternalAuthProvider(federation_provider)

        # If not token federation, just use the credentials provider directly
        return ExternalAuthProvider(cfg.credentials_provider)

    # If we don't have a credentials provider but have token federation auth type with access token
    if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token:
        # Create a simple credentials provider and wrap it with token federation provider
        from databricks.sql.auth.token_federation import (
            DatabricksTokenFederationProvider,
            SimpleCredentialsProvider,
        )

        simple_provider = SimpleCredentialsProvider(cfg.access_token)
        federation_provider = DatabricksTokenFederationProvider(
            simple_provider, cfg.hostname, cfg.identity_federation_client_id
        )
        return ExternalAuthProvider(federation_provider)

    if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
        assert cfg.oauth_redirect_port_range is not None
        assert cfg.oauth_client_id is not None

@@ -102,6 +161,27 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):


def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
    """
    Get an auth provider for the Python SQL connector.

    This function is the main entry point for authentication in the SQL connector.
    It processes the parameters and creates an appropriate auth provider.

    TODO: Future refactoring needed:
    1. Add a use_token_federation flag that can be combined with any auth type
    2. Remove TOKEN_FEDERATION as an auth_type while maintaining backward compatibility
Review comment: This will be hard to do once it has been introduced.
    3. Create a token federation wrapper that can wrap any existing auth provider

    Args:
        hostname: The Databricks server hostname
        **kwargs: Additional configuration parameters

    Returns:
        An appropriate AuthProvider instance

    Raises:
        ValueError: If username/password authentication is attempted (no longer supported)
    """
    auth_type = kwargs.get("auth_type")
    (client_id, redirect_port_range) = get_client_id_and_redirect_port(
        auth_type == AuthType.AZURE_OAUTH.value

@@ -125,5 +205,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
        else redirect_port_range,
        oauth_persistence=kwargs.get("experimental_oauth_persistence"),
        credentials_provider=kwargs.get("credentials_provider"),
        identity_federation_client_id=kwargs.get("identity_federation_client_id"),
    )
    return get_auth_provider(cfg)
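
The branch order in `get_auth_provider` is easier to review when it is reduced to its decisions. Below is a minimal, self-contained sketch of that order. `select_provider` and the string labels it returns are hypothetical stand-ins for the real provider classes, and only the branches visible in this diff are modelled.

```python
from typing import Callable, Optional

# Value of AuthType.TOKEN_FEDERATION as shown in the diff.
TOKEN_FEDERATION = "token-federation"


def select_provider(
    auth_type: Optional[str],
    credentials_provider: Optional[Callable] = None,
    access_token: Optional[str] = None,
) -> str:
    """Return a label naming the branch the diff's logic would take."""
    # An explicit credentials provider always wins.
    if credentials_provider:
        if auth_type == TOKEN_FEDERATION:
            return "federation(credentials_provider)"
        return "credentials_provider"
    # Without a provider, federation needs an access token to wrap.
    if auth_type == TOKEN_FEDERATION and access_token:
        return "federation(access_token)"
    # Everything else continues to the OAuth and access-token branches.
    return "fall through"


print(select_provider(TOKEN_FEDERATION, access_token="example-token"))
print(select_provider(TOKEN_FEDERATION))
```

One consequence visible in the sketch: `auth_type="token-federation"` with neither a credentials provider nor an access token matches no federation branch and continues to the later checks.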
@@ -0,0 +1,75 @@
import logging
import requests
from typing import Optional
from urllib.parse import urlparse

from databricks.sql.auth.endpoint import (
    get_oauth_endpoints,
    infer_cloud_from_host,
)

logger = logging.getLogger(__name__)


class OIDCDiscoveryUtil:
    """
    Utility class for OIDC discovery operations.

    This class handles discovery of OIDC endpoints through standard
    discovery mechanisms, with fallback to default endpoints if needed.
    """

    # Standard token endpoint path for Databricks workspaces
    DEFAULT_TOKEN_PATH = "oidc/v1/token"

    @staticmethod
    def discover_token_endpoint(hostname: str) -> str:
        """
        Get the token endpoint for the given Databricks hostname.

        For Databricks workspaces, the token endpoint is always at host/oidc/v1/token.

        Args:
            hostname: The hostname to get token endpoint for

        Returns:
            str: The token endpoint URL
        """
        # Format the hostname and return the standard endpoint
        hostname = OIDCDiscoveryUtil.format_hostname(hostname)
        token_endpoint = f"{hostname}{OIDCDiscoveryUtil.DEFAULT_TOKEN_PATH}"
        logger.info(f"Using token endpoint: {token_endpoint}")
        return token_endpoint

    @staticmethod
    def format_hostname(hostname: str) -> str:
        """
        Format hostname to ensure it has proper https:// prefix and trailing slash.

        Args:
            hostname: The hostname to format

        Returns:
            str: The formatted hostname
        """
        if not hostname.startswith("https://"):
Review comment: What if someone passes a hostname that starts with http://?
            hostname = f"https://{hostname}"
        if not hostname.endswith("/"):
            hostname = f"{hostname}/"
        return hostname


def is_same_host(url1: str, url2: str) -> bool:
    """
    Check if two URLs have the same host.
    """
    try:
        if not url1.startswith(("http://", "https://")):
            url1 = f"https://{url1}"
        if not url2.startswith(("http://", "https://")):
            url2 = f"https://{url2}"
        parsed1 = urlparse(url1)
        parsed2 = urlparse(url2)
        return parsed1.netloc.lower() == parsed2.netloc.lower()
    except Exception:
        return False
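
With `format_hostname` as written in the diff, an input such as `http://example.cloud.databricks.com` would become `https://http://example.cloud.databricks.com/`, because only the `https://` prefix is checked. One possible way to handle that case is sketched below. Upgrading `http://` to `https://` rather than rejecting it is an assumption made here for illustration, and the standalone functions are hypothetical rewrites, not the PR's code.

```python
DEFAULT_TOKEN_PATH = "oidc/v1/token"  # same path as in OIDCDiscoveryUtil


def format_hostname(hostname: str) -> str:
    """Return the hostname with an https:// prefix and a trailing slash."""
    hostname = hostname.strip()
    if hostname.startswith("http://"):
        # Assumption: silently upgrade plain HTTP instead of raising.
        hostname = "https://" + hostname[len("http://"):]
    elif not hostname.startswith("https://"):
        hostname = f"https://{hostname}"
    if not hostname.endswith("/"):
        hostname = f"{hostname}/"
    return hostname


def token_endpoint(hostname: str) -> str:
    """Build the token endpoint URL from any accepted hostname form."""
    return f"{format_hostname(hostname)}{DEFAULT_TOKEN_PATH}"


for host in (
    "example.cloud.databricks.com",
    "http://example.cloud.databricks.com",
    "https://example.cloud.databricks.com/",
):
    print(token_endpoint(host))
```

All three inputs produce the same endpoint URL. Raising a `ValueError` for `http://` would be the stricter alternative if silent upgrades are undesirable.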
@@ -0,0 +1,65 @@
"""
Token class for authentication tokens with expiry handling.
"""

from datetime import datetime, timezone, timedelta
from typing import Optional


class Token:
    """
    Represents an OAuth token with expiry information.

    This class handles token state including expiry calculation.
    """

    # Minimum time buffer before expiry to consider a token still valid (in seconds)
    MIN_VALIDITY_BUFFER = 10

    def __init__(
        self,
        access_token: str,
        token_type: str,
        refresh_token: str = "",
        expiry: Optional[datetime] = None,
    ):
        """
        Initialize a Token object.

        Args:
            access_token: The access token string
            token_type: The token type (usually "Bearer")
            refresh_token: Optional refresh token
            expiry: Token expiry datetime, must be provided

        Raises:
            ValueError: If no expiry is provided
        """
        self.access_token = access_token
        self.token_type = token_type
        self.refresh_token = refresh_token

        # Ensure we have an expiry time
        if expiry is None:
            raise ValueError("Token expiry must be provided")

        # Ensure expiry is timezone-aware
        if expiry.tzinfo is None:
            # Convert naive datetime to aware datetime
            self.expiry = expiry.replace(tzinfo=timezone.utc)
Review comment: Does this mean that a datetime without a timezone will be treated as UTC?
        else:
            self.expiry = expiry

    def is_valid(self) -> bool:
        """
        Check if the token is valid (has at least MIN_VALIDITY_BUFFER seconds before expiry).

        Returns:
            bool: True if the token is valid, False otherwise
        """
        buffer = timedelta(seconds=self.MIN_VALIDITY_BUFFER)
        return datetime.now(tz=timezone.utc) + buffer < self.expiry

    def __str__(self) -> str:
        """Return the token as a string in the format used for Authorization headers."""
        return f"{self.token_type} {self.access_token}"
Review comment: Authorization headers are typically of the form "scheme token", where the scheme can be "Bearer", an API key scheme, and so on. The token type looks more like one of our internal concepts.
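
The expiry handling in the `Token` class can be checked in isolation. The sketch below reproduces the two rules from the diff with a fixed clock: a naive datetime gets UTC attached through `replace`, which relabels the value without shifting it, and a token counts as valid only while more than `MIN_VALIDITY_BUFFER` seconds remain. The functions `normalize_expiry` and `is_valid` and the injected `now` argument are assumptions made so the example is deterministic; they are not part of the PR.

```python
from datetime import datetime, timedelta, timezone

MIN_VALIDITY_BUFFER = 10  # seconds, as in the Token class


def normalize_expiry(expiry: datetime) -> datetime:
    """Attach UTC to a naive datetime; leave aware datetimes untouched."""
    if expiry.tzinfo is None:
        # replace() relabels the value as UTC; it does not convert from local time.
        return expiry.replace(tzinfo=timezone.utc)
    return expiry


def is_valid(expiry: datetime, now: datetime) -> bool:
    """True while more than MIN_VALIDITY_BUFFER seconds remain before expiry."""
    buffer = timedelta(seconds=MIN_VALIDITY_BUFFER)
    return now + buffer < normalize_expiry(expiry)


now = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
naive_expiry = datetime(2025, 1, 1, 12, 0, 30)  # no tzinfo, read as 12:00:30 UTC

print(is_valid(naive_expiry, now))                 # 30 s left, buffer is 10 s
print(is_valid(now + timedelta(seconds=5), now))   # 5 s left, inside the buffer
```

A caller that passes a naive datetime expressed in local time would therefore get an expiry that is off by the local UTC offset, which is the case the reviewer's question points at.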
Review comment: Won't this run as part of the release process? What if a release contains no changes in the code folders listed here? We should still run this as part of the release, if not on every PR.