Skip to content
Open

wip #1259

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ test:
pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests

integration:
# TODO: Remove dual-run once experimental_is_unified_host flag is removed and unified mode becomes default
@echo "Running integration tests in unified mode..."
DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST=true pytest -n auto -m 'integration and not benchmark' --reruns 4 --dist loadgroup --cov=databricks --cov-append --cov-report html tests
@echo "Running integration tests in legacy mode..."
pytest -n auto -m 'integration and not benchmark' --reruns 4 --dist loadgroup --cov=databricks --cov-report html tests

benchmark:
Expand Down
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## Release v0.88.0

### New Features and Improvements
* Remove cloud type restrictions from Azure/GCP credential providers. Azure and GCP authentication now works with any Databricks host when credentials are properly configured, enabling authentication against cloud-agnostic endpoints such as aliased hosts.
* Add support for legacy Profiles in Unified Mode. It is now possible to use any host in Unified Mode.

### Security

Expand Down
91 changes: 67 additions & 24 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
DatabricksEnvironment, get_environment_for_hostname)
from .oauth import (OidcEndpoints, Token, get_account_endpoints,
get_azure_entra_id_workspace_endpoints,
get_unified_endpoints, get_workspace_endpoints)
get_workspace_endpoints)

logger = logging.getLogger("databricks.sdk")

Expand All @@ -46,7 +46,10 @@ def __get__(self, cfg: "Config", owner):
return cfg._inner.get(self.name, None)

def __set__(self, cfg: "Config", value: any):
cfg._inner[self.name] = self.transform(value)
if value is None:
cfg._inner.pop(self.name, None)
else:
cfg._inner[self.name] = self.transform(value)

def __repr__(self) -> str:
return f"<ConfigAttribute '{self.name}' {self.transform.__name__}>"
Expand Down Expand Up @@ -280,6 +283,7 @@ def __init__(
self.databricks_environment = kwargs["databricks_environment"]
del kwargs["databricks_environment"]
self._clock = clock if clock is not None else RealClock()

try:
self._set_inner_config(kwargs)
self._load_from_env()
Expand All @@ -288,6 +292,9 @@ def __init__(
self._validate()
self.init_auth()
self._init_product(product, product_version)
# Extract the workspace ID for legacy profiles. This is extracted from an API call.
if not self.workspace_id and not self.account_id and self.experimental_is_unified_host:
self.workspace_id = self._fetch_workspace_id()
except ValueError as e:
message = self.wrap_debug_info(str(e))
raise ValueError(message) from e
Expand Down Expand Up @@ -369,15 +376,15 @@ def environment(self) -> DatabricksEnvironment:
def is_azure(self) -> bool:
if self.azure_workspace_resource_id:
return True
return self.environment.cloud == Cloud.AZURE
return self.environment is not None and self.environment.cloud == Cloud.AZURE

@property
def is_gcp(self) -> bool:
return self.environment.cloud == Cloud.GCP
return self.environment is not None and self.environment.cloud == Cloud.GCP

@property
def is_aws(self) -> bool:
return self.environment.cloud == Cloud.AWS
return self.environment is not None and self.environment.cloud == Cloud.AWS

@property
def host_type(self) -> HostType:
Expand All @@ -400,7 +407,9 @@ def host_type(self) -> HostType:

@property
def client_type(self) -> ClientType:
"""Determine the type of client configuration.
"""
[Deprecated] Deprecated. Use host_type instead. Some hosts can support both account and workspace clients.
Determine the type of client configuration.

This is separate from host_type. For example, a unified host can support both
workspace and account client types.
Expand All @@ -419,25 +428,24 @@ def client_type(self) -> ClientType:
return ClientType.WORKSPACE

if host_type == HostType.UNIFIED:
if not self.account_id:
raise ValueError("Unified host requires account_id to be set")
if self.workspace_id:
return ClientType.WORKSPACE
return ClientType.ACCOUNT
if self.account_id:
return ClientType.ACCOUNT
# Legacy workspace hosts don't have a workspace_id until AFTER the auth is resolved.
return ClientType.WORKSPACE

# Default to workspace for backward compatibility
return ClientType.WORKSPACE

@property
def is_account_client(self) -> bool:
"""[Deprecated] Use host_type or client_type instead.
"""[Deprecated] Use host_type instead.

Determines if this is an account client based on the host URL.
Determines if this config is compatible with an account client based on the host URL and account_id.
"""
if self.experimental_is_unified_host:
raise ValueError(
"is_account_client cannot be used with unified hosts; use host_type or client_type instead"
)
return self.account_id
if not self.host:
return False
return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.")
Expand Down Expand Up @@ -505,10 +513,8 @@ def databricks_oidc_endpoints(self) -> Optional[OidcEndpoints]:
return None

# Handle unified hosts
if self.host_type == HostType.UNIFIED:
if not self.account_id:
raise ValueError("Unified host requires account_id to be set for OAuth endpoints")
return get_unified_endpoints(self.host, self.account_id)
if self.experimental_is_unified_host:
return self._experimental_oidc_discovery()

# Handle traditional account hosts
if self.host_type == HostType.ACCOUNTS and self.account_id:
Expand All @@ -517,6 +523,32 @@ def databricks_oidc_endpoints(self) -> Optional[OidcEndpoints]:
# Default to workspace endpoints
return get_workspace_endpoints(self.host)

def _experimental_oidc_discovery(self) -> Optional[OidcEndpoints]:
    """[Experimental] Discover OIDC endpoints for Databricks OAuth.

    Probes multiple well-known paths: account-level endpoints first (when
    ``account_id`` is configured), then workspace-level endpoints as a
    fallback. A failure at the account level is logged and swallowed so the
    workspace fallback can still run; only when every probe fails is a
    ValueError raised.

    This is not to be used for production purposes.
    It is only to be used for testing and development purposes until a
    unified OIDC endpoint is available.

    :raises ValueError: if the experimental_is_unified_host flag is not set,
        or if no probe yields endpoints.
    """

    # DO NOT REMOVE THIS ERROR. THIS IS USED TO ENFORCE THAT THE EXPERIMENTAL IS UNIFIED HOST FLAG IS SET.
    # THIS WHOLE METHOD SHOULD BE REMOVED WHEN THE EXPERIMENTAL IS UNIFIED HOST FLAG IS REMOVED.
    if not self.experimental_is_unified_host:
        raise ValueError(
            "experimental_oidc_discovery is only supported with the experimental_is_unified_host flag set"
        )
    if self.account_id:
        try:
            return get_account_endpoints(self.host, self.account_id)
        except Exception as e:
            # Best-effort: fall through to workspace-level discovery below.
            logger.warning(f"Failed to discover OIDC endpoints for account {self.account_id}: {e}")
    try:
        return get_workspace_endpoints(self.host)
    except Exception as e:
        logger.warning(f"Failed to discover OIDC endpoints for workspace {self.workspace_id}: {e}")
    raise ValueError("Failed to discover OIDC endpoints")

@property
def oidc_endpoints(self) -> Optional[OidcEndpoints]:
"""[DEPRECATED] Get OIDC endpoints with automatic Azure detection (deprecated).
Expand Down Expand Up @@ -574,12 +606,9 @@ def sql_http_path(self) -> Optional[str]:
return None
if self.cluster_id and self.warehouse_id:
raise ValueError("cannot have both cluster_id and warehouse_id")
headers = self.authenticate()
headers["User-Agent"] = f"{self.user_agent} sdk-feature/sql-http-path"
if self.cluster_id:
response = requests.get(f"{self.host}/api/2.0/preview/scim/v2/Me", headers=headers)
# get workspace ID from the response header
workspace_id = response.headers.get("x-databricks-org-id")
# Reuse cached workspace_id or fetch it
workspace_id = self.workspace_id or self._fetch_workspace_id()
return f"sql/protocolv1/o/{workspace_id}/{self.cluster_id}"
if self.warehouse_id:
return f"/sql/1.0/warehouses/{self.warehouse_id}"
Expand Down Expand Up @@ -621,7 +650,7 @@ def load_azure_tenant_id(self):
"""[Internal] Load the Azure tenant ID from the Azure Databricks login page.

If the tenant ID is already set, this method does nothing."""
if not self.is_azure or self.azure_tenant_id is not None or self.host is None:
if self.azure_tenant_id is not None or self.host is None:
return
login_url = f"{self.host}/aad/auth"
logger.debug(f"Loading tenant ID from {login_url}")
Expand Down Expand Up @@ -770,3 +799,17 @@ def copy(self):
def deep_copy(self):
"""Creates a deep copy of the config object."""
return copy.deepcopy(self)

# The code below is used to support legacy hosts.
def _fetch_workspace_id(self) -> Optional[str]:
    """Fetch the workspace ID for the configured host.

    Calls the SCIM ``/Me`` endpoint and returns the value of the
    ``x-databricks-org-id`` response header. Best-effort: any failure
    (auth, network, non-2xx status) is logged at debug level and None is
    returned instead of raising.
    """
    try:
        headers = self.authenticate()
        # NOTE(review): the feature tag "sql-http-path" predates this helper's
        # use for legacy-profile workspace-ID resolution — consider a dedicated
        # tag; the runtime string is deliberately left unchanged here.
        headers["User-Agent"] = f"{self.user_agent} sdk-feature/sql-http-path"
        response = requests.get(f"{self.host}/api/2.0/preview/scim/v2/Me", headers=headers)
        response.raise_for_status()
        # get workspace ID from the response header (not the response body)
        return response.headers.get("x-databricks-org-id")
    except Exception as e:
        logger.debug(f"Failed to fetch workspace ID: {e}")
        return None
23 changes: 8 additions & 15 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], oauth.

@oauth_credentials_strategy(
"azure-client-secret",
["is_azure", "azure_client_id", "azure_client_secret"],
["azure_client_id", "azure_client_secret"],
)
def azure_service_principal(cfg: "Config") -> CredentialsProvider:
"""Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
Expand Down Expand Up @@ -439,9 +439,9 @@ def _oidc_credentials_provider(

# Determine the audience for token exchange
audience = cfg.token_audience
if audience is None and cfg.client_type == ClientType.ACCOUNT:
if audience is None and cfg.account_id:
audience = cfg.account_id
if audience is None and cfg.client_type != ClientType.ACCOUNT:
if audience is None and not cfg.account_id:
audience = cfg.databricks_oidc_endpoints.token_endpoint

# Try to get an OIDC token. If no supplier returns a token, we cannot use this authentication mode.
Expand Down Expand Up @@ -513,17 +513,14 @@ def azure_devops_oidc(cfg: "Config") -> Optional[CredentialsProvider]:
)


# Azure Client ID is the minimal thing we need, as otherwise we get AADSTS700016: Application with
# identifier 'https://token.actions.githubusercontent.com' was not found in the directory '...'.
@oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
# not in GitHub actions
return None

# Client ID is the minimal thing we need, as otherwise we get AADSTS700016: Application with
# identifier 'https://token.actions.githubusercontent.com' was not found in the directory '...'.
if not cfg.is_azure:
return None

token = oidc_token_supplier.GitHubOIDCTokenSupplier().get_oidc_token("api://AzureADTokenExchange")
if not token:
return None
Expand Down Expand Up @@ -572,8 +569,6 @@ def token() -> oauth.Token:

@oauth_credentials_strategy("google-credentials", ["host", "google_credentials"])
def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]:
if not cfg.is_gcp:
return None
# Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
# Obtain the id token by providing the json file path and target audience.
if os.path.isfile(cfg.google_credentials):
Expand All @@ -598,7 +593,7 @@ def token() -> oauth.Token:
def refreshed_headers() -> Dict[str, str]:
credentials.refresh(request)
headers = {"Authorization": f"Bearer {credentials.token}"}
if cfg.client_type == ClientType.ACCOUNT:
if cfg.account_id:
gcp_credentials.refresh(request)
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
return headers
Expand All @@ -608,8 +603,6 @@ def refreshed_headers() -> Dict[str, str]:

@oauth_credentials_strategy("google-id", ["host", "google_service_account"])
def google_id(cfg: "Config") -> Optional[CredentialsProvider]:
if not cfg.is_gcp:
return None
credentials, _project_id = google.auth.default()

# Create the impersonated credential.
Expand Down Expand Up @@ -639,7 +632,7 @@ def token() -> oauth.Token:
def refreshed_headers() -> Dict[str, str]:
id_creds.refresh(request)
headers = {"Authorization": f"Bearer {id_creds.token}"}
if cfg.client_type == ClientType.ACCOUNT:
if cfg.account_id:
gcp_impersonated_credentials.refresh(request)
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
return headers
Expand Down Expand Up @@ -814,7 +807,7 @@ def get_subscription(cfg: "Config") -> Optional[str]:
return components[2]


@credentials_strategy("azure-cli", ["is_azure"])
@credentials_strategy("azure-cli", ["effective_azure_login_app_id"])
def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]:
"""Adds refreshed OAuth token granted by `az login` command to every request."""
cfg.load_azure_tenant_id()
Expand Down
13 changes: 9 additions & 4 deletions databricks/sdk/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,20 @@ class Cloud(Enum):
AWS = "AWS"
AZURE = "AZURE"
GCP = "GCP"
UNKNOWN = "UNKNOWN"


@dataclass
class DatabricksEnvironment:
cloud: Cloud
dns_zone: str
dns_zone: Optional[str] = None
azure_application_id: Optional[str] = None
azure_environment: Optional[AzureEnvironment] = None

def deployment_url(self, name: str) -> str:
    """Return the https URL for deployment *name* in this environment.

    :raises ValueError: when this environment has no DNS zone (unified
        environments do not have a separate per-workspace host).
    """
    # Unified environments do not have a separate workspace host.
    if self.dns_zone is None:
        raise ValueError("This environment does not support deployment URLs.")
    return f"https://{name}{self.dns_zone}"

@property
Expand All @@ -70,13 +74,13 @@ def azure_active_directory_endpoint(self) -> Optional[str]:
return self.azure_environment.active_directory_endpoint


DEFAULT_ENVIRONMENT = DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.com")
DEFAULT_ENVIRONMENT = DatabricksEnvironment(Cloud.UNKNOWN, None)

ALL_ENVS = [
DatabricksEnvironment(Cloud.AWS, ".dev.databricks.com"),
DatabricksEnvironment(Cloud.AWS, ".staging.cloud.databricks.com"),
DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.us"),
DEFAULT_ENVIRONMENT,
DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.com"),
DatabricksEnvironment(
Cloud.AZURE,
".dev.azuredatabricks.net",
Expand Down Expand Up @@ -110,13 +114,14 @@ def azure_active_directory_endpoint(self) -> Optional[str]:
DatabricksEnvironment(Cloud.GCP, ".dev.gcp.databricks.com"),
DatabricksEnvironment(Cloud.GCP, ".staging.gcp.databricks.com"),
DatabricksEnvironment(Cloud.GCP, ".gcp.databricks.com"),
DEFAULT_ENVIRONMENT,
]


def get_environment_for_hostname(hostname: Optional[str]) -> DatabricksEnvironment:
if not hostname:
return DEFAULT_ENVIRONMENT
for env in ALL_ENVS:
if hostname.endswith(env.dns_zone):
if env.dns_zone and hostname.endswith(env.dns_zone):
return env
return DEFAULT_ENVIRONMENT
4 changes: 2 additions & 2 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def a(env_or_skip) -> AccountClient:
_load_debug_env_if_runs_from_ide("account")
env_or_skip("CLOUD_ENV")
account_client = AccountClient()
if not account_client.config.is_account_client:
if not account_client.config.account_id:
pytest.skip("not Databricks Account client")
return account_client

Expand All @@ -75,7 +75,7 @@ def ucacct(env_or_skip) -> AccountClient:
_load_debug_env_if_runs_from_ide("ucacct")
env_or_skip("CLOUD_ENV")
account_client = AccountClient()
if not account_client.config.is_account_client:
if not account_client.config.account_id:
pytest.skip("not Databricks Account client")
if "TEST_METASTORE_ID" not in os.environ:
pytest.skip("not in Unity Catalog Workspace test env")
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import io
import json
import os
import re
import shutil
import subprocess
Expand Down Expand Up @@ -261,6 +262,9 @@ def test_wif_workspace(ucacct, env_or_skip, random):
permissions=[iam.WorkspacePermission.ADMIN],
)

# Clean env var
os.environ.pop("DATABRICKS_ACCOUNT_ID", None)

ws = WorkspaceClient(
host=workspace_url,
client_id=sp.application_id,
Expand Down
Loading
Loading