Skip to content

Commit 8319854

Browse files
authored
[PECO-1414] Support Databricks native OAuth in Azure (#351)
* [PECO-1414] Support Databricks InHouse OAuth in Azure

Signed-off-by: Jacky Hu <[email protected]>
1 parent 1b469c0 commit 8319854

File tree

6 files changed

+243
-98
lines changed

6 files changed

+243
-98
lines changed

src/databricks/sql/auth/auth.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
ExternalAuthProvider,
99
DatabricksOAuthProvider,
1010
)
11-
from databricks.sql.auth.endpoint import infer_cloud_from_host, CloudType
12-
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
1311

1412

1513
class AuthType(Enum):
1614
DATABRICKS_OAUTH = "databricks-oauth"
15+
AZURE_OAUTH = "azure-oauth"
1716
# other supported types (access_token, user/pass) can be inferred
1817
# we can add more types as needed later
1918

@@ -51,7 +50,7 @@ def __init__(
5150
def get_auth_provider(cfg: ClientContext):
5251
if cfg.credentials_provider:
5352
return ExternalAuthProvider(cfg.credentials_provider)
54-
if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
53+
if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
5554
assert cfg.oauth_redirect_port_range is not None
5655
assert cfg.oauth_client_id is not None
5756
assert cfg.oauth_scopes is not None
@@ -62,6 +61,7 @@ def get_auth_provider(cfg: ClientContext):
6261
cfg.oauth_redirect_port_range,
6362
cfg.oauth_client_id,
6463
cfg.oauth_scopes,
64+
cfg.auth_type,
6565
)
6666
elif cfg.access_token is not None:
6767
return AccessTokenAuthProvider(cfg.access_token)
@@ -87,20 +87,22 @@ def normalize_host_name(hostname: str):
8787
return f"{maybe_scheme}{hostname}{maybe_trailing_slash}"
8888

8989

90-
def get_client_id_and_redirect_port(hostname: str):
91-
cloud_type = infer_cloud_from_host(hostname)
90+
def get_client_id_and_redirect_port(use_azure_auth: bool):
9291
return (
9392
(PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE)
94-
if cloud_type == CloudType.AWS or cloud_type == CloudType.GCP
93+
if not use_azure_auth
9594
else (PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE)
9695
)
9796

9897

9998
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
100-
(client_id, redirect_port_range) = get_client_id_and_redirect_port(hostname)
99+
auth_type = kwargs.get("auth_type")
100+
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
101+
auth_type == AuthType.AZURE_OAUTH.value
102+
)
101103
cfg = ClientContext(
102104
hostname=normalize_host_name(hostname),
103-
auth_type=kwargs.get("auth_type"),
105+
auth_type=auth_type,
104106
access_token=kwargs.get("access_token"),
105107
username=kwargs.get("_username"),
106108
password=kwargs.get("_password"),

src/databricks/sql/auth/authenticators.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def add_headers(self, request_headers: Dict[str, str]):
1818

1919
HeaderFactory = Callable[[], Dict[str, str]]
2020

21+
2122
# In order to keep compatibility with SDK
2223
class CredentialsProvider(abc.ABC):
2324
"""CredentialsProvider is the protocol (call-side interface)
@@ -69,16 +70,13 @@ def __init__(
6970
redirect_port_range: List[int],
7071
client_id: str,
7172
scopes: List[str],
73+
auth_type: str = "databricks-oauth",
7274
):
7375
try:
74-
cloud_type = infer_cloud_from_host(hostname)
75-
if not cloud_type:
76-
raise NotImplementedError("Cannot infer the cloud type from hostname")
77-
78-
idp_endpoint = get_oauth_endpoints(cloud_type)
76+
idp_endpoint = get_oauth_endpoints(hostname, auth_type == "azure-oauth")
7977
if not idp_endpoint:
8078
raise NotImplementedError(
81-
f"OAuth is not supported for cloud ${cloud_type.value}"
79+
f"OAuth is not supported for host ${hostname}"
8280
)
8381

8482
# Convert to the corresponding scopes in the corresponding IdP

src/databricks/sql/auth/endpoint.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#
22
# It implements all the cloud specific OAuth configuration/metadata
33
#
4-
# Azure: It uses AAD
4+
# Azure: It uses Databricks internal IdP or Azure AD
55
# AWS: It uses Databricks internal IdP
6-
# GCP: Not support yet
6+
# GCP: It uses Databricks internal IdP
77
#
88
from abc import ABC, abstractmethod
99
from enum import Enum
@@ -37,6 +37,9 @@ class CloudType(Enum):
3737
]
3838
DATABRICKS_GCP_DOMAINS = [".gcp.databricks.com"]
3939

40+
# Azure domains supported by Databricks in-house (native) OAuth
41+
DATABRICKS_OAUTH_AZURE_DOMAINS = [".azuredatabricks.net"]
42+
4043

4144
# Infer cloud type from Databricks SQL instance hostname
4245
def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
@@ -53,6 +56,14 @@ def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
5356
return None
5457

5558

59+
def is_supported_databricks_oauth_host(hostname: str) -> bool:
    """Return True when *hostname* ends with a Databricks in-house OAuth domain.

    The comparison is case-insensitive. Any ``https://`` text is removed and
    everything from the first ``/`` onward is dropped before the remaining
    host is matched against the AWS, GCP and Azure domain suffix lists.
    """
    host = hostname.lower().replace("https://", "").split("/")[0]
    supported_suffixes = tuple(
        DATABRICKS_AWS_DOMAINS + DATABRICKS_GCP_DOMAINS + DATABRICKS_OAUTH_AZURE_DOMAINS
    )
    # str.endswith accepts a tuple of suffixes, so the result is the match
    # itself rather than the truthiness of whichever suffix happened to match.
    return host.endswith(supported_suffixes)
65+
66+
5667
def get_databricks_oidc_url(hostname: str):
5768
maybe_scheme = "https://" if not hostname.startswith("https://") else ""
5869
maybe_trailing_slash = "/" if not hostname.endswith("/") else ""
@@ -112,10 +123,18 @@ def get_openid_config_url(self, hostname: str):
112123
return f"{idp_url}/.well-known/oauth-authorization-server"
113124

114125

115-
def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]:
116-
if cloud == CloudType.AWS or cloud == CloudType.GCP:
126+
def get_oauth_endpoints(
127+
hostname: str, use_azure_auth: bool
128+
) -> Optional[OAuthEndpointCollection]:
129+
cloud = infer_cloud_from_host(hostname)
130+
131+
if cloud in [CloudType.AWS, CloudType.GCP]:
117132
return InHouseOAuthEndpointCollection()
118133
elif cloud == CloudType.AZURE:
119-
return AzureOAuthEndpointCollection()
134+
return (
135+
InHouseOAuthEndpointCollection()
136+
if is_supported_databricks_oauth_host(hostname) and not use_azure_auth
137+
else AzureOAuthEndpointCollection()
138+
)
120139
else:
121140
return None

src/databricks/sql/client.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def __init__(
9696
legacy purposes and will be deprecated in a future release. When this parameter is `True` you will see
9797
a warning log message. To suppress this log message, set `use_inline_params="silent"`.
9898
auth_type: `str`, optional
99-
`databricks-oauth` : to use oauth with fine-grained permission scopes, set to `databricks-oauth`.
99+
`databricks-oauth` : to use Databricks OAuth with fine-grained permission scopes, set to `databricks-oauth`.
100+
`azure-oauth` : to use the Microsoft Entra ID OAuth flow, set to `azure-oauth`.
100101
101102
oauth_client_id: `str`, optional
102103
custom oauth client_id. If not specified, it will use the built-in client_id of databricks-sql-python.
@@ -107,9 +108,9 @@ def __init__(
107108
108109
experimental_oauth_persistence: configures preferred storage for persisting oauth tokens.
109110
This has to be a class implementing `OAuthPersistence`.
110-
When `auth_type` is set to `databricks-oauth` without persisting the oauth token in a persistence storage
111-
the oauth tokens will only be maintained in memory and if the python process restarts the end user
112-
will have to login again.
111+
When `auth_type` is set to `databricks-oauth` or `azure-oauth` without persisting the oauth token in a
112+
persistence storage the oauth tokens will only be maintained in memory and if the python process
113+
restarts, the end user will have to log in again.
113114
Note this is beta (private preview)
114115
115116
For persisting the oauth token in a prod environment you should subclass and implement OAuthPersistence

0 commit comments

Comments
 (0)