Skip to content

Commit 1b469c0

Browse files
jackyhu-dbJesse Whitehouse
and
Jesse Whitehouse
authored
[PECO-1411] Support Databricks OAuth on GCP (#338)
* [PECO-1411] Support OAuth InHouse on GCP Signed-off-by: Jacky Hu <[email protected]> * Update changelog Signed-off-by: Jesse Whitehouse <[email protected]> --------- Signed-off-by: Jacky Hu <[email protected]> Signed-off-by: Jesse Whitehouse <[email protected]> Co-authored-by: Jesse Whitehouse <[email protected]>
1 parent 5a06ccd commit 1b469c0

File tree

4 files changed

+14
-7
lines changed

4 files changed

+14
-7
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# 3.0.3 (2024-02-02)
44

5+
- Add support in-house OAuth on GCP (#338)
56
- Revised docstrings and examples for OAuth (#339)
67
- Redact the URL query parameters from the urllib3.connectionpool logs (#341)
78

src/databricks/sql/auth/auth.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,10 @@ def normalize_host_name(hostname: str):
8888

8989

9090
def get_client_id_and_redirect_port(hostname: str):
91+
cloud_type = infer_cloud_from_host(hostname)
9192
return (
9293
(PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE)
93-
if infer_cloud_from_host(hostname) == CloudType.AWS
94+
if cloud_type == CloudType.AWS or cloud_type == CloudType.GCP
9495
else (PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE)
9596
)
9697

src/databricks/sql/auth/endpoint.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class OAuthScope:
2121
class CloudType(Enum):
2222
AWS = "aws"
2323
AZURE = "azure"
24+
GCP = "gcp"
2425

2526

2627
DATABRICKS_AWS_DOMAINS = [
@@ -34,6 +35,7 @@ class CloudType(Enum):
3435
".databricks.azure.cn",
3536
".databricks.azure.us",
3637
]
38+
DATABRICKS_GCP_DOMAINS = [".gcp.databricks.com"]
3739

3840

3941
# Infer cloud type from Databricks SQL instance hostname
@@ -45,6 +47,8 @@ def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
4547
return CloudType.AZURE
4648
elif any(e for e in DATABRICKS_AWS_DOMAINS if host.endswith(e)):
4749
return CloudType.AWS
50+
elif any(e for e in DATABRICKS_GCP_DOMAINS if host.endswith(e)):
51+
return CloudType.GCP
4852
else:
4953
return None
5054

@@ -94,7 +98,7 @@ def get_openid_config_url(self, hostname: str):
9498
return "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration"
9599

96100

97-
class AwsOAuthEndpointCollection(OAuthEndpointCollection):
101+
class InHouseOAuthEndpointCollection(OAuthEndpointCollection):
98102
def get_scopes_mapping(self, scopes: List[str]) -> List[str]:
99103
# No scope mapping in AWS
100104
return scopes.copy()
@@ -109,8 +113,8 @@ def get_openid_config_url(self, hostname: str):
109113

110114

111115
def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]:
112-
if cloud == CloudType.AWS:
113-
return AwsOAuthEndpointCollection()
116+
if cloud == CloudType.AWS or cloud == CloudType.GCP:
117+
return InHouseOAuthEndpointCollection()
114118
elif cloud == CloudType.AZURE:
115119
return AzureOAuthEndpointCollection()
116120
else:

tests/unit/test_auth.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
88
from databricks.sql.auth.oauth import OAuthManager
99
from databricks.sql.auth.authenticators import DatabricksOAuthProvider
10-
from databricks.sql.auth.endpoint import CloudType, AwsOAuthEndpointCollection, AzureOAuthEndpointCollection
10+
from databricks.sql.auth.endpoint import CloudType, InHouseOAuthEndpointCollection, AzureOAuthEndpointCollection
1111
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
1212
from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache
1313

@@ -55,9 +55,10 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh):
5555
mock_get_tokens.return_value = (access_token, refresh_token)
5656
mock_check_and_refresh.return_value = (access_token, refresh_token, False)
5757

58-
params = [(CloudType.AWS, "foo.cloud.databricks.com", AwsOAuthEndpointCollection, "offline_access sql"),
58+
params = [(CloudType.AWS, "foo.cloud.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql"),
5959
(CloudType.AZURE, "foo.1.azuredatabricks.net", AzureOAuthEndpointCollection,
60-
f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access")]
60+
f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access"),
61+
(CloudType.GCP, "foo.gcp.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql")]
6162

6263
for cloud_type, host, expected_endpoint_type, expected_scopes in params:
6364
with self.subTest(cloud_type.value):

0 commit comments

Comments
 (0)