Skip to content

Commit 54e3769

Browse files
authored
[PECO-626] Support OAuth flow for Databricks Azure (databricks#86)
## Summary Support the OAuth flow for Databricks on Azure ## Background Some OAuth endpoints (e.g. the OpenID configuration) and scopes differ between Databricks on Azure and Databricks on AWS. The current code only supports the OAuth flow for Databricks on AWS ## What changes are proposed in this pull request? - Change `OAuthManager` to decouple the Databricks-on-AWS-specific configuration from the OAuth flow - Add `sql/auth/endpoint.py`, which implements the cloud-specific OAuth endpoint configuration - Change `DatabricksOAuthProvider` to work with the OAuth configurations of the different Databricks clouds (AWS, Azure) - Add the corresponding unit tests
1 parent bbe539e commit 54e3769

File tree

8 files changed

+273
-24
lines changed

8 files changed

+273
-24
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## 2.6.x (Unreleased)
44

5+
- Add support for OAuth on Databricks Azure
6+
57
## 2.6.2 (2023-06-14)
68

79
- Fix: Retry GetOperationStatus requests for http errors

src/databricks/sql/auth/auth.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ExternalAuthProvider,
99
DatabricksOAuthProvider,
1010
)
11+
from databricks.sql.auth.endpoint import infer_cloud_from_host, CloudType
1112
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
1213

1314

@@ -75,7 +76,9 @@ def get_auth_provider(cfg: ClientContext):
7576

7677
PYSQL_OAUTH_SCOPES = ["sql", "offline_access"]
7778
PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python"
79+
PYSQL_OAUTH_AZURE_CLIENT_ID = "96eecda7-19ea-49cc-abb5-240097d554f5"
7880
PYSQL_OAUTH_REDIRECT_PORT_RANGE = list(range(8020, 8025))
81+
PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE = [8030]
7982

8083

8184
def normalize_host_name(hostname: str):
@@ -84,7 +87,16 @@ def normalize_host_name(hostname: str):
8487
return f"{maybe_scheme}{hostname}{maybe_trailing_slash}"
8588

8689

90+
def get_client_id_and_redirect_port(hostname: str):
    """Pick the default OAuth client id and redirect ports for ``hostname``.

    Hosts that ``infer_cloud_from_host`` classifies as AWS get the AWS
    defaults. Every other host, including one whose cloud cannot be
    inferred, gets the Azure defaults.

    Returns:
        A ``(client_id, redirect_port_range)`` tuple.
    """
    if infer_cloud_from_host(hostname) == CloudType.AWS:
        return (PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE)
    return (PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE)
96+
97+
8798
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
99+
(client_id, redirect_port_range) = get_client_id_and_redirect_port(hostname)
88100
cfg = ClientContext(
89101
hostname=normalize_host_name(hostname),
90102
auth_type=kwargs.get("auth_type"),
@@ -94,10 +106,10 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
94106
use_cert_as_auth=kwargs.get("_use_cert_as_auth"),
95107
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
96108
oauth_scopes=PYSQL_OAUTH_SCOPES,
97-
oauth_client_id=kwargs.get("oauth_client_id") or PYSQL_OAUTH_CLIENT_ID,
109+
oauth_client_id=kwargs.get("oauth_client_id") or client_id,
98110
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
99111
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
100-
else PYSQL_OAUTH_REDIRECT_PORT_RANGE,
112+
else redirect_port_range,
101113
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
102114
credentials_provider=kwargs.get("credentials_provider"),
103115
)

src/databricks/sql/auth/authenticators.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Callable, Dict, List
55

66
from databricks.sql.auth.oauth import OAuthManager
7+
from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host
78

89
# Private API: this is an evolving interface and it will change in the future.
910
# Please do not depend on it in your applications.
@@ -70,11 +71,26 @@ def __init__(
7071
scopes: List[str],
7172
):
7273
try:
74+
cloud_type = infer_cloud_from_host(hostname)
75+
if not cloud_type:
76+
raise NotImplementedError("Cannot infer the cloud type from hostname")
77+
78+
idp_endpoint = get_oauth_endpoints(cloud_type)
79+
if not idp_endpoint:
80+
raise NotImplementedError(
81+
f"OAuth is not supported for cloud ${cloud_type.value}"
82+
)
83+
84+
# Convert to the corresponding scopes in the corresponding IdP
85+
cloud_scopes = idp_endpoint.get_scopes_mapping(scopes)
86+
7387
self.oauth_manager = OAuthManager(
74-
port_range=redirect_port_range, client_id=client_id
88+
port_range=redirect_port_range,
89+
client_id=client_id,
90+
idp_endpoint=idp_endpoint,
7591
)
7692
self._hostname = hostname
77-
self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes)
93+
self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes)
7894
self._oauth_persistence = oauth_persistence
7995
self._client_id = client_id
8096
self._access_token = None

src/databricks/sql/auth/endpoint.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#
2+
# Implements all the cloud-specific OAuth configuration/metadata
3+
#
4+
# Azure: It uses AAD
5+
# AWS: It uses Databricks internal IdP
6+
# GCP: Not supported yet
7+
#
8+
from abc import ABC, abstractmethod
9+
from enum import Enum
10+
from typing import Optional, List
11+
import os
12+
13+
OIDC_REDIRECTOR_PATH = "oidc"
14+
15+
16+
class OAuthScope:
    """String constants for the OAuth scopes referenced by this module."""

    SQL = "sql"
    OFFLINE_ACCESS = "offline_access"
19+
20+
21+
class CloudType(Enum):
    """Clouds that a workspace hostname can be classified as."""

    AWS = "aws"
    AZURE = "azure"
24+
25+
26+
# Hostname suffixes that identify a workspace hosted on AWS.
DATABRICKS_AWS_DOMAINS = [
    ".cloud.databricks.com",
    ".dev.databricks.com",
]
# Hostname suffixes that identify a workspace hosted on Azure.
DATABRICKS_AZURE_DOMAINS = [
    ".azuredatabricks.net",
    ".databricks.azure.cn",
    ".databricks.azure.us",
]
32+
33+
34+
# Infer cloud type from Databricks SQL instance hostname
35+
def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
    """Infer the cloud of a Databricks SQL instance from its hostname.

    The hostname is lower-cased, and any ``https://`` scheme, URL path and
    ``:port`` suffix are removed before the domain suffix is matched.

    Args:
        hostname: Workspace hostname or URL, e.g.
            ``foo.cloud.databricks.com`` or
            ``https://foo.azuredatabricks.net/``.

    Returns:
        ``CloudType.AZURE`` or ``CloudType.AWS`` when the host ends with one
        of the known domain suffixes, otherwise ``None``.
    """
    # normalize: drop the scheme, then keep only the part before the path
    host = hostname.lower().replace("https://", "").split("/")[0]
    # An explicit port (e.g. "host:443") must not hide the domain suffix.
    host = host.split(":")[0]

    # str.endswith accepts a tuple of suffixes, so no explicit loop is needed.
    if host.endswith(tuple(DATABRICKS_AZURE_DOMAINS)):
        return CloudType.AZURE
    if host.endswith(tuple(DATABRICKS_AWS_DOMAINS)):
        return CloudType.AWS
    return None
45+
46+
47+
def get_databricks_oidc_url(hostname: str):
    """Build the OIDC URL for ``hostname``.

    ``https://`` is prepended unless the hostname already starts with it,
    and a ``/`` is inserted before the OIDC path unless the hostname already
    ends with one.
    """
    scheme = "" if hostname.startswith("https://") else "https://"
    separator = "" if hostname.endswith("/") else "/"
    return scheme + hostname + separator + OIDC_REDIRECTOR_PATH
51+
52+
53+
class OAuthEndpointCollection(ABC):
    """Interface for the cloud-specific OAuth endpoints and scope mapping."""

    @abstractmethod
    def get_scopes_mapping(self, scopes: List[str]) -> List[str]:
        """Return the scopes to request from the IdP for ``scopes``."""
        raise NotImplementedError

    @abstractmethod
    def get_authorization_url(self, hostname: str) -> str:
        """Return the endpoint for oauth2 authorization.

        e.g https://idp.example.com/oauth2/v2.0/authorize
        """
        raise NotImplementedError

    @abstractmethod
    def get_openid_config_url(self, hostname: str) -> str:
        """Return the endpoint for the well-known openid configuration.

        e.g https://idp.example.com/oauth2/.well-known/openid-configuration
        """
        raise NotImplementedError
67+
68+
69+
class AzureOAuthEndpointCollection(OAuthEndpointCollection):
    """OAuth endpoints and scope mapping used for Databricks on Azure."""

    # Default prefix of the "user_impersonation" scope.
    # NOTE(review): presumably the Azure application id of Databricks - confirm.
    DATATRICKS_AZURE_APP = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"

    def get_scopes_mapping(self, scopes: List[str]) -> List[str]:
        """Translate the requested scopes into the scopes sent to Azure.

        The result always starts with ``<id>/user_impersonation``, where
        ``<id>`` is the ``DATABRICKS_AZURE_TENANT_ID`` environment variable
        when it is set and ``DATATRICKS_AZURE_APP`` otherwise.
        ``offline_access`` is appended only when it was requested.
        """
        # There are no corresponding scopes in Azure; instead, access control
        # is delegated to Databricks.
        scope_prefix = os.getenv("DATABRICKS_AZURE_TENANT_ID")
        if scope_prefix is None:
            scope_prefix = AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP

        result = [f"{scope_prefix}/user_impersonation"]
        if OAuthScope.OFFLINE_ACCESS in scopes:
            result.append(OAuthScope.OFFLINE_ACCESS)
        return result

    def get_authorization_url(self, hostname: str):
        """Return the authorize endpoint under the workspace's OIDC URL."""
        # We need the account-specific URL, which the Databricks unified
        # OIDC endpoint can redirect to.
        oidc_url = get_databricks_oidc_url(hostname)
        return oidc_url + "/oauth2/v2.0/authorize"

    def get_openid_config_url(self, hostname: str):
        """Return the fixed Microsoft OpenID configuration URL.

        ``hostname`` is not used.
        """
        return "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration"
90+
91+
92+
class AwsOAuthEndpointCollection(OAuthEndpointCollection):
    """OAuth endpoints and scope mapping used for Databricks on AWS."""

    def get_scopes_mapping(self, scopes: List[str]) -> List[str]:
        """Return a copy of ``scopes``; AWS needs no scope mapping."""
        return scopes.copy()

    def get_authorization_url(self, hostname: str):
        """Return the authorize endpoint under the workspace's OIDC URL."""
        return get_databricks_oidc_url(hostname) + "/oauth2/v2.0/authorize"

    def get_openid_config_url(self, hostname: str):
        """Return the OAuth metadata URL under the workspace's OIDC URL."""
        return (
            get_databricks_oidc_url(hostname)
            + "/.well-known/oauth-authorization-server"
        )
104+
105+
106+
def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]:
    """Create the endpoint collection matching ``cloud``.

    Returns a new ``AwsOAuthEndpointCollection`` or
    ``AzureOAuthEndpointCollection``, or ``None`` for any other value.
    """
    if cloud == CloudType.AWS:
        return AwsOAuthEndpointCollection()
    if cloud == CloudType.AZURE:
        return AzureOAuthEndpointCollection()
    return None

src/databricks/sql/auth/oauth.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,22 @@
1414
from requests.exceptions import RequestException
1515

1616
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
17+
from databricks.sql.auth.endpoint import OAuthEndpointCollection
1718

1819
logger = logging.getLogger(__name__)
1920

2021

2122
class OAuthManager:
22-
OIDC_REDIRECTOR_PATH = "oidc"
23-
24-
def __init__(self, port_range: List[int], client_id: str):
23+
def __init__(
    self,
    port_range: List[int],
    client_id: str,
    idp_endpoint: OAuthEndpointCollection,
):
    """Store the settings used by the OAuth flow.

    Args:
        port_range: Ports kept on ``self.port_range``.
        client_id: OAuth client id used to build the oauthlib client.
        idp_endpoint: Cloud-specific endpoint collection that supplies the
            OpenID configuration URL and the authorization URL.
    """
    self.idp_endpoint = idp_endpoint
    self.client_id = client_id
    self.port_range = port_range
    # No redirect port has been chosen yet.
    self.redirect_port = None
2833

2934
@staticmethod
3035
def __token_urlsafe(nbytes=32):
@@ -34,14 +39,14 @@ def __token_urlsafe(nbytes=32):
3439
def __get_redirect_url(redirect_port: int):
3540
return f"http://localhost:{redirect_port}"
3641

37-
@staticmethod
38-
def __fetch_well_known_config(idp_url: str):
39-
known_config_url = f"{idp_url}/.well-known/oauth-authorization-server"
42+
def __fetch_well_known_config(self, hostname: str):
43+
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)
44+
4045
try:
4146
response = requests.get(url=known_config_url)
4247
except RequestException as e:
4348
logger.error(
44-
f"Unable to fetch OAuth configuration from {idp_url}.\n"
49+
f"Unable to fetch OAuth configuration from {known_config_url}.\n"
4550
"Verify it is a valid workspace URL and that OAuth is "
4651
"enabled on this account."
4752
)
@@ -50,7 +55,7 @@ def __fetch_well_known_config(idp_url: str):
5055
if response.status_code != 200:
5156
msg = (
5257
f"Received status {response.status_code} OAuth configuration from "
53-
f"{idp_url}.\n Verify it is a valid workspace URL and "
58+
f"{known_config_url}.\n Verify it is a valid workspace URL and "
5459
"that OAuth is enabled on this account."
5560
)
5661
logger.error(msg)
@@ -59,18 +64,12 @@ def __fetch_well_known_config(idp_url: str):
5964
return response.json()
6065
except requests.exceptions.JSONDecodeError as e:
6166
logger.error(
62-
f"Unable to decode OAuth configuration from {idp_url}.\n"
67+
f"Unable to decode OAuth configuration from {known_config_url}.\n"
6368
"Verify it is a valid workspace URL and that OAuth is "
6469
"enabled on this account."
6570
)
6671
raise e
6772

68-
@staticmethod
69-
def __get_idp_url(host: str):
70-
maybe_scheme = "https://" if not host.startswith("https://") else ""
71-
maybe_trailing_slash = "/" if not host.endswith("/") else ""
72-
return f"{maybe_scheme}{host}{maybe_trailing_slash}{OAuthManager.OIDC_REDIRECTOR_PATH}"
73-
7473
@staticmethod
7574
def __get_challenge():
7675
verifier_string = OAuthManager.__token_urlsafe(32)
@@ -154,8 +153,7 @@ def __send_token_request(token_request_url, data):
154153
return response.json()
155154

156155
def __send_refresh_token_request(self, hostname, refresh_token):
157-
idp_url = OAuthManager.__get_idp_url(hostname)
158-
oauth_config = OAuthManager.__fetch_well_known_config(idp_url)
156+
oauth_config = self.__fetch_well_known_config(hostname)
159157
token_request_url = oauth_config["token_endpoint"]
160158
client = oauthlib.oauth2.WebApplicationClient(self.client_id)
161159
token_request_body = client.prepare_refresh_body(
@@ -215,14 +213,15 @@ def check_and_refresh_access_token(
215213
return fresh_access_token, fresh_refresh_token, True
216214

217215
def get_tokens(self, hostname: str, scope=None):
218-
idp_url = self.__get_idp_url(hostname)
219-
oauth_config = self.__fetch_well_known_config(idp_url)
216+
oauth_config = self.__fetch_well_known_config(hostname)
220217
# We are going to override oauth_config["authorization_endpoint"] to use the
221218
# /oidc redirector on the hostname, which may inject additional parameters.
222-
auth_url = f"{hostname}oidc/v1/authorize"
219+
auth_url = self.idp_endpoint.get_authorization_url(hostname)
220+
223221
state = OAuthManager.__token_urlsafe(16)
224222
(verifier, challenge) = OAuthManager.__get_challenge()
225223
client = oauthlib.oauth2.WebApplicationClient(self.client_id)
224+
226225
try:
227226
auth_response = self.__get_authorization_code(
228227
client, auth_url, scope, state, challenge

src/databricks/sql/experimental/oauth_persistence.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ def read(self, hostname: str) -> Optional[OAuthToken]:
2727
pass
2828

2929

30+
class OAuthPersistenceCache(OAuthPersistence):
    """``OAuthPersistence`` that keeps tokens in an in-memory dict."""

    def __init__(self):
        # Maps hostname -> the token most recently persisted for it.
        self.tokens = {}

    def persist(self, hostname: str, oauth_token: OAuthToken):
        """Store ``oauth_token`` for ``hostname``, replacing any earlier one."""
        self.tokens[hostname] = oauth_token

    def read(self, hostname: str) -> Optional[OAuthToken]:
        """Return the token stored for ``hostname``, or ``None`` if absent."""
        return self.tokens.get(hostname, None)
39+
40+
3041
# Note this is only intended to be used for development
3142
class DevOnlyFilePersistence(OAuthPersistence):
3243
def __init__(self, file_path):

tests/unit/test_auth.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
import unittest
2+
import pytest
3+
from typing import Optional
4+
from unittest.mock import patch
25

36
from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider, ExternalAuthProvider
47
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
8+
from databricks.sql.auth.oauth import OAuthManager
9+
from databricks.sql.auth.authenticators import DatabricksOAuthProvider
10+
from databricks.sql.auth.endpoint import CloudType, AwsOAuthEndpointCollection, AzureOAuthEndpointCollection
511
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
12+
from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache
613

714

815
class Auth(unittest.TestCase):
@@ -38,6 +45,39 @@ def test_noop_auth_provider(self):
3845
self.assertEqual(len(http_request.keys()), 1)
3946
self.assertEqual(http_request['myKey'], 'myVal')
4047

48+
@patch.object(OAuthManager, "check_and_refresh_access_token")
@patch.object(OAuthManager, "get_tokens")
def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh):
    """DatabricksOAuthProvider picks the endpoint type and scopes per cloud."""
    client_id = "mock-id"
    scopes = ["offline_access", "sql"]
    access_token = "mock_token"
    refresh_token = "mock_refresh_token"
    # Both OAuthManager entry points are patched, so no network call is made.
    mock_get_tokens.return_value = (access_token, refresh_token)
    mock_check_and_refresh.return_value = (access_token, refresh_token, False)

    azure_scopes = (
        f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}"
        "/user_impersonation offline_access"
    )
    # (cloud, hostname, expected endpoint class, expected scope string)
    cases = [
        (
            CloudType.AWS,
            "foo.cloud.databricks.com",
            AwsOAuthEndpointCollection,
            "offline_access sql",
        ),
        (
            CloudType.AZURE,
            "foo.1.azuredatabricks.net",
            AzureOAuthEndpointCollection,
            azure_scopes,
        ),
    ]

    for cloud, host, endpoint_cls, scope_str in cases:
        with self.subTest(cloud.value):
            persistence = OAuthPersistenceCache()
            provider = DatabricksOAuthProvider(
                hostname=host,
                oauth_persistence=persistence,
                redirect_port_range=[8020],
                client_id=client_id,
                scopes=scopes,
            )
            manager = provider.oauth_manager

            self.assertIsInstance(manager.idp_endpoint, endpoint_cls)
            self.assertEqual(manager.port_range, [8020])
            self.assertEqual(manager.client_id, client_id)
            self.assertEqual(persistence.read(host).refresh_token, refresh_token)
            mock_get_tokens.assert_called_with(hostname=host, scope=scope_str)

            headers = {}
            provider.add_headers(headers)
            self.assertEqual(headers['Authorization'], f"Bearer {access_token}")
80+
4181
def test_external_provider(self):
4282
class MyProvider(CredentialsProvider):
4383
def auth_type(self) -> str:

0 commit comments

Comments
 (0)