Skip to content

Commit d3df719

Browse files
Refactor codebase to use a unified http client (databricks#673)
* Refactor codebase to use a unified http client
  Signed-off-by: Vikrant Puppala <[email protected]>
* Some more fixes and aligned tests
  Signed-off-by: Vikrant Puppala <[email protected]>
* Fix all tests
  Signed-off-by: Vikrant Puppala <[email protected]>
* fmt
  Signed-off-by: Vikrant Puppala <[email protected]>
* fix e2e
  Signed-off-by: Vikrant Puppala <[email protected]>
* fix unit
  Signed-off-by: Vikrant Puppala <[email protected]>
* more fixes
  Signed-off-by: Vikrant Puppala <[email protected]>
* more fixes
  Signed-off-by: Vikrant Puppala <[email protected]>
* review comments
  Signed-off-by: Vikrant Puppala <[email protected]>
* fix warnings
  Signed-off-by: Vikrant Puppala <[email protected]>
* fix check-types
  Signed-off-by: Vikrant Puppala <[email protected]>
* remove separate http client for telemetry
  Signed-off-by: Vikrant Puppala <[email protected]>
* more clean up
  Signed-off-by: Vikrant Puppala <[email protected]>
* more fixes
  Signed-off-by: Vikrant Puppala <[email protected]>
* more fixes
  Signed-off-by: Vikrant Puppala <[email protected]>
* remove finally
  Signed-off-by: Vikrant Puppala <[email protected]>
---------
Signed-off-by: Vikrant Puppala <[email protected]>
1 parent fd81c5a commit d3df719

32 files changed

+925
-718
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from databricks.sql.auth.common import AuthType, ClientContext
1111

1212

13-
def get_auth_provider(cfg: ClientContext):
13+
def get_auth_provider(cfg: ClientContext, http_client):
1414
if cfg.credentials_provider:
1515
return ExternalAuthProvider(cfg.credentials_provider)
1616
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
@@ -19,6 +19,7 @@ def get_auth_provider(cfg: ClientContext):
1919
cfg.hostname,
2020
cfg.azure_client_id,
2121
cfg.azure_client_secret,
22+
http_client,
2223
cfg.azure_tenant_id,
2324
cfg.azure_workspace_resource_id,
2425
)
@@ -34,6 +35,7 @@ def get_auth_provider(cfg: ClientContext):
3435
cfg.oauth_redirect_port_range,
3536
cfg.oauth_client_id,
3637
cfg.oauth_scopes,
38+
http_client,
3739
cfg.auth_type,
3840
)
3941
elif cfg.access_token is not None:
@@ -53,6 +55,8 @@ def get_auth_provider(cfg: ClientContext):
5355
cfg.oauth_redirect_port_range,
5456
cfg.oauth_client_id,
5557
cfg.oauth_scopes,
58+
http_client,
59+
cfg.auth_type or AuthType.DATABRICKS_OAUTH.value,
5660
)
5761
else:
5862
raise RuntimeError("No valid authentication settings!")
@@ -79,7 +83,7 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
7983
)
8084

8185

82-
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
86+
def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs):
8387
# TODO : unify all the auth mechanisms with the Python SDK
8488

8589
auth_type = kwargs.get("auth_type")
@@ -111,4 +115,4 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
111115
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
112116
credentials_provider=kwargs.get("credentials_provider"),
113117
)
114-
return get_auth_provider(cfg)
118+
return get_auth_provider(cfg, http_client)

src/databricks/sql/auth/authenticators.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ def __init__(
6363
redirect_port_range: List[int],
6464
client_id: str,
6565
scopes: List[str],
66+
http_client,
6667
auth_type: str = "databricks-oauth",
6768
):
6869
try:
@@ -79,6 +80,7 @@ def __init__(
7980
port_range=redirect_port_range,
8081
client_id=client_id,
8182
idp_endpoint=idp_endpoint,
83+
http_client=http_client,
8284
)
8385
self._hostname = hostname
8486
self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes)
@@ -188,6 +190,7 @@ def __init__(
188190
hostname,
189191
azure_client_id,
190192
azure_client_secret,
193+
http_client,
191194
azure_tenant_id=None,
192195
azure_workspace_resource_id=None,
193196
):
@@ -196,8 +199,9 @@ def __init__(
196199
self.azure_client_secret = azure_client_secret
197200
self.azure_workspace_resource_id = azure_workspace_resource_id
198201
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
199-
hostname
202+
hostname, http_client
200203
)
204+
self._http_client = http_client
201205

202206
def auth_type(self) -> str:
203207
return AuthType.AZURE_SP_M2M.value
@@ -207,6 +211,7 @@ def get_token_source(self, resource: str) -> RefreshableTokenSource:
207211
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
208212
client_id=self.azure_client_id,
209213
client_secret=self.azure_client_secret,
214+
http_client=self._http_client,
210215
extra_params={"resource": resource},
211216
)
212217

src/databricks/sql/auth/common.py

Lines changed: 49 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,8 @@
22
import logging
33
from typing import Optional, List
44
from urllib.parse import urlparse
5-
from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
5+
from databricks.sql.auth.retry import DatabricksRetryPolicy
6+
from databricks.sql.common.http import HttpMethod
67

78
logger = logging.getLogger(__name__)
89

@@ -36,6 +37,21 @@ def __init__(
3637
tls_client_cert_file: Optional[str] = None,
3738
oauth_persistence=None,
3839
credentials_provider=None,
40+
# HTTP client configuration parameters
41+
ssl_options=None, # SSLOptions type
42+
socket_timeout: Optional[float] = None,
43+
retry_stop_after_attempts_count: Optional[int] = None,
44+
retry_delay_min: Optional[float] = None,
45+
retry_delay_max: Optional[float] = None,
46+
retry_stop_after_attempts_duration: Optional[float] = None,
47+
retry_delay_default: Optional[float] = None,
48+
retry_dangerous_codes: Optional[List[int]] = None,
49+
http_proxy: Optional[str] = None,
50+
proxy_username: Optional[str] = None,
51+
proxy_password: Optional[str] = None,
52+
pool_connections: Optional[int] = None,
53+
pool_maxsize: Optional[int] = None,
54+
user_agent: Optional[str] = None,
3955
):
4056
self.hostname = hostname
4157
self.access_token = access_token
@@ -52,6 +68,24 @@ def __init__(
5268
self.oauth_persistence = oauth_persistence
5369
self.credentials_provider = credentials_provider
5470

71+
# HTTP client configuration
72+
self.ssl_options = ssl_options
73+
self.socket_timeout = socket_timeout
74+
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5
75+
self.retry_delay_min = retry_delay_min or 1.0
76+
self.retry_delay_max = retry_delay_max or 10.0
77+
self.retry_stop_after_attempts_duration = (
78+
retry_stop_after_attempts_duration or 300.0
79+
)
80+
self.retry_delay_default = retry_delay_default or 5.0
81+
self.retry_dangerous_codes = retry_dangerous_codes or []
82+
self.http_proxy = http_proxy
83+
self.proxy_username = proxy_username
84+
self.proxy_password = proxy_password
85+
self.pool_connections = pool_connections or 10
86+
self.pool_maxsize = pool_maxsize or 20
87+
self.user_agent = user_agent
88+
5589

5690
def get_effective_azure_login_app_id(hostname) -> str:
5791
"""
@@ -69,7 +103,7 @@ def get_effective_azure_login_app_id(hostname) -> str:
69103
return AzureAppId.PROD.value[1]
70104

71105

72-
def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
106+
def get_azure_tenant_id_from_host(host: str, http_client) -> str:
73107
"""
74108
Load the Azure tenant ID from the Azure Databricks login page.
75109
@@ -78,23 +112,20 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
78112
the Azure login page, and the tenant ID is extracted from the redirect URL.
79113
"""
80114

81-
if http_client is None:
82-
http_client = DatabricksHttpClient.get_instance()
83-
84115
login_url = f"{host}/aad/auth"
85116
logger.debug("Loading tenant ID from %s", login_url)
86-
with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
87-
if resp.status_code // 100 != 3:
117+
118+
with http_client.request_context(HttpMethod.GET, login_url) as resp:
119+
entra_id_endpoint = resp.retries.history[-1].redirect_location
120+
if entra_id_endpoint is None:
88121
raise ValueError(
89-
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
122+
f"No Location header in response from {login_url}: {entra_id_endpoint}"
90123
)
91-
entra_id_endpoint = resp.headers.get("Location")
92-
if entra_id_endpoint is None:
93-
raise ValueError(f"No Location header in response from {login_url}")
94-
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
95-
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
96-
url = urlparse(entra_id_endpoint)
97-
path_segments = url.path.split("/")
98-
if len(path_segments) < 2:
99-
raise ValueError(f"Invalid path in Location header: {url.path}")
100-
return path_segments[1]
124+
125+
# The final redirect URL has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
126+
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
127+
url = urlparse(entra_id_endpoint)
128+
path_segments = url.path.split("/")
129+
if len(path_segments) < 2:
130+
raise ValueError(f"Invalid path in Location header: {url.path}")
131+
return path_segments[1]

src/databricks/sql/auth/oauth.py

Lines changed: 32 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,8 @@
99
from typing import List, Optional
1010

1111
import oauthlib.oauth2
12-
import requests
1312
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
14-
from requests.exceptions import RequestException
15-
from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader
13+
from databricks.sql.common.http import HttpMethod, HttpHeader
1614
from databricks.sql.common.http import OAuthResponse
1715
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
1816
from databricks.sql.auth.endpoint import OAuthEndpointCollection
@@ -63,33 +61,19 @@ def refresh(self) -> Token:
6361
pass
6462

6563

66-
class IgnoreNetrcAuth(requests.auth.AuthBase):
67-
"""This auth method is a no-op.
68-
69-
We use it to force requestslib to not use .netrc to write auth headers
70-
when making .post() requests to the oauth token endpoints, since these
71-
don't require authentication.
72-
73-
In cases where .netrc is outdated or corrupt, these requests will fail.
74-
75-
See issue #121
76-
"""
77-
78-
def __call__(self, r):
79-
return r
80-
81-
8264
class OAuthManager:
8365
def __init__(
8466
self,
8567
port_range: List[int],
8668
client_id: str,
8769
idp_endpoint: OAuthEndpointCollection,
70+
http_client,
8871
):
8972
self.port_range = port_range
9073
self.client_id = client_id
9174
self.redirect_port = None
9275
self.idp_endpoint = idp_endpoint
76+
self.http_client = http_client
9377

9478
@staticmethod
9579
def __token_urlsafe(nbytes=32):
@@ -103,8 +87,11 @@ def __fetch_well_known_config(self, hostname: str):
10387
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)
10488

10589
try:
106-
response = requests.get(url=known_config_url, auth=IgnoreNetrcAuth())
107-
except RequestException as e:
90+
response = self.http_client.request(HttpMethod.GET, url=known_config_url)
91+
# Convert urllib3 response to requests-like response for compatibility
92+
response.status_code = response.status
93+
response.json = lambda: json.loads(response.data.decode())
94+
except Exception as e:
10895
logger.error(
10996
f"Unable to fetch OAuth configuration from {known_config_url}.\n"
11097
"Verify it is a valid workspace URL and that OAuth is "
@@ -122,7 +109,7 @@ def __fetch_well_known_config(self, hostname: str):
122109
raise RuntimeError(msg)
123110
try:
124111
return response.json()
125-
except requests.exceptions.JSONDecodeError as e:
112+
except Exception as e:
126113
logger.error(
127114
f"Unable to decode OAuth configuration from {known_config_url}.\n"
128115
"Verify it is a valid workspace URL and that OAuth is "
@@ -203,16 +190,17 @@ def __send_auth_code_token_request(
203190
data = f"{token_request_body}&code_verifier={verifier}"
204191
return self.__send_token_request(token_request_url, data)
205192

206-
@staticmethod
207-
def __send_token_request(token_request_url, data):
193+
def __send_token_request(self, token_request_url, data):
208194
headers = {
209195
"Accept": "application/json",
210196
"Content-Type": "application/x-www-form-urlencoded",
211197
}
212-
response = requests.post(
213-
url=token_request_url, data=data, headers=headers, auth=IgnoreNetrcAuth()
198+
# Use unified HTTP client
199+
response = self.http_client.request(
200+
HttpMethod.POST, url=token_request_url, body=data, headers=headers
214201
)
215-
return response.json()
202+
# Convert urllib3 response to dict for compatibility
203+
return json.loads(response.data.decode())
216204

217205
def __send_refresh_token_request(self, hostname, refresh_token):
218206
oauth_config = self.__fetch_well_known_config(hostname)
@@ -221,7 +209,7 @@ def __send_refresh_token_request(self, hostname, refresh_token):
221209
token_request_body = client.prepare_refresh_body(
222210
refresh_token=refresh_token, client_id=client.client_id
223211
)
224-
return OAuthManager.__send_token_request(token_request_url, token_request_body)
212+
return self.__send_token_request(token_request_url, token_request_body)
225213

226214
@staticmethod
227215
def __get_tokens_from_response(oauth_response):
@@ -320,14 +308,15 @@ def __init__(
320308
token_url,
321309
client_id,
322310
client_secret,
311+
http_client,
323312
extra_params: dict = {},
324313
):
325314
self.client_id = client_id
326315
self.client_secret = client_secret
327316
self.token_url = token_url
328317
self.extra_params = extra_params
329318
self.token: Optional[Token] = None
330-
self._http_client = DatabricksHttpClient.get_instance()
319+
self._http_client = http_client
331320

332321
def get_token(self) -> Token:
333322
if self.token is None or self.token.is_expired():
@@ -348,17 +337,17 @@ def refresh(self) -> Token:
348337
}
349338
)
350339

351-
with self._http_client.execute(
352-
method=HttpMethod.POST, url=self.token_url, headers=headers, data=data
353-
) as response:
354-
if response.status_code == 200:
355-
oauth_response = OAuthResponse(**response.json())
356-
return Token(
357-
oauth_response.access_token,
358-
oauth_response.token_type,
359-
oauth_response.refresh_token,
360-
)
361-
else:
362-
raise Exception(
363-
f"Failed to get token: {response.status_code} {response.text}"
364-
)
340+
response = self._http_client.request(
341+
method=HttpMethod.POST, url=self.token_url, headers=headers, body=data
342+
)
343+
if response.status == 200:
344+
oauth_response = OAuthResponse(**json.loads(response.data.decode("utf-8")))
345+
return Token(
346+
oauth_response.access_token,
347+
oauth_response.token_type,
348+
oauth_response.refresh_token,
349+
)
350+
else:
351+
raise Exception(
352+
f"Failed to get token: {response.status} {response.data.decode('utf-8')}"
353+
)

src/databricks/sql/auth/retry.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -355,8 +355,14 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
355355
logger.info(f"Received status code {status_code} for {method} request")
356356

357357
# Request succeeded. Don't retry.
358-
if status_code == 200:
359-
return False, "200 codes are not retried"
358+
if status_code // 100 <= 3:
359+
return False, "2xx/3xx codes are not retried"
360+
361+
if status_code == 400:
362+
return (
363+
False,
364+
"Received 400 - BAD_REQUEST. Please check the request parameters.",
365+
)
360366

361367
if status_code == 401:
362368
return (

src/databricks/sql/backend/sea/queue.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,7 @@ def build_queue(
5050
max_download_threads: int,
5151
sea_client: SeaDatabricksClient,
5252
lz4_compressed: bool,
53+
http_client,
5354
) -> ResultSetQueue:
5455
"""
5556
Factory method to build a result set queue for SEA backend.
@@ -94,6 +95,7 @@ def build_queue(
9495
total_chunk_count=manifest.total_chunk_count,
9596
lz4_compressed=lz4_compressed,
9697
description=description,
98+
http_client=http_client,
9799
)
98100
raise ProgrammingError("Invalid result format")
99101

@@ -309,6 +311,7 @@ def __init__(
309311
sea_client: SeaDatabricksClient,
310312
statement_id: str,
311313
total_chunk_count: int,
314+
http_client,
312315
lz4_compressed: bool = False,
313316
description: List[Tuple] = [],
314317
):
@@ -337,6 +340,7 @@ def __init__(
337340
# TODO: fix these arguments when telemetry is implemented in SEA
338341
session_id_hex=None,
339342
chunk_id=0,
343+
http_client=http_client,
340344
)
341345

342346
logger.debug(

0 commit comments

Comments
 (0)