Skip to content

Commit 8b97bb8

Browse files
committed
fix the integration workflow and linting
1 parent 8187766 commit 8b97bb8

File tree

2 files changed

+56
-20
lines changed

2 files changed

+56
-20
lines changed

.github/workflows/integration.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,4 @@ jobs:
6262
- name: Run e2e tests
6363
run: poetry run python -m pytest tests/e2e
6464
- name: Run SQL Alchemy tests
65-
run: poetry run python -m pytest sqlalchemy/tests/test_local
65+
run: poetry run python -m pytest src/databricks/sqlalchemy/tests/test_local

src/databricks/sql/client.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ def read(self) -> Optional[OAuthToken]:
205205
self.disable_pandas = kwargs.get("_disable_pandas", False)
206206
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
207207

208-
auth_provider = get_python_sql_connector_auth_provider(server_hostname, **kwargs)
208+
auth_provider = get_python_sql_connector_auth_provider(
209+
server_hostname, **kwargs
210+
)
209211

210212
if not kwargs.get("_user_agent_entry"):
211213
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
@@ -276,7 +278,8 @@ def __exit__(self, exc_type, exc_value, traceback):
276278
def __del__(self):
277279
if self.open:
278280
logger.debug(
279-
"Closing unclosed connection for session " "{}".format(self.get_session_id_hex())
281+
"Closing unclosed connection for session "
282+
"{}".format(self.get_session_id_hex())
280283
)
281284
try:
282285
self._close(close_cursors=False)
@@ -356,9 +359,13 @@ def _close(self, close_cursors=True) -> None:
356359
logger.info("Session was closed by a prior request")
357360
except DatabaseError as e:
358361
if "Invalid SessionHandle" in str(e):
359-
logger.warning(f"Attempted to close session that was already closed: {e}")
362+
logger.warning(
363+
f"Attempted to close session that was already closed: {e}"
364+
)
360365
else:
361-
logger.warning(f"Attempt to close session raised an exception at the server: {e}")
366+
logger.warning(
367+
f"Attempt to close session raised an exception at the server: {e}"
368+
)
362369
except Exception as e:
363370
logger.error(f"Attempt to close session raised a local exception: {e}")
364371

@@ -441,7 +448,9 @@ def _all_dbsql_parameters_are_named(self, params: List[TDbsqlParameter]) -> bool
441448
"""Return True if all members of the list have a non-null .name attribute"""
442449
return all([i.name is not None for i in params])
443450

444-
def _normalize_tparametersequence(self, params: TParameterSequence) -> List[TDbsqlParameter]:
451+
def _normalize_tparametersequence(
452+
self, params: TParameterSequence
453+
) -> List[TDbsqlParameter]:
445454
"""Retains the same order as the input list."""
446455

447456
output: List[TDbsqlParameter] = []
@@ -453,9 +462,12 @@ def _normalize_tparametersequence(self, params: TParameterSequence) -> List[TDbs
453462

454463
return output
455464

456-
def _normalize_tparameterdict(self, params: TParameterDict) -> List[TDbsqlParameter]:
465+
def _normalize_tparameterdict(
466+
self, params: TParameterDict
467+
) -> List[TDbsqlParameter]:
457468
return [
458-
dbsql_parameter_from_primitive(value=value, name=name) for name, value in params.items()
469+
dbsql_parameter_from_primitive(value=value, name=name)
470+
for name, value in params.items()
459471
]
460472

461473
def _normalize_tparametercollection(
@@ -528,7 +540,8 @@ def _prepare_native_parameters(
528540

529541
stmt = stmt
530542
output = [
531-
p.as_tspark_param(named=param_structure == ParameterStructure.NAMED) for p in params
543+
p.as_tspark_param(named=param_structure == ParameterStructure.NAMED)
544+
for p in params
532545
]
533546

534547
return stmt, output
@@ -544,7 +557,9 @@ def _check_not_closed(self):
544557
if not self.open:
545558
raise Error("Attempting operation on closed cursor")
546559

547-
def _handle_staging_operation(self, staging_allowed_local_path: Union[None, str, List[str]]):
560+
def _handle_staging_operation(
561+
self, staging_allowed_local_path: Union[None, str, List[str]]
562+
):
548563
"""Fetch the HTTP request instruction from a staging ingestion command
549564
and call the designated handler.
550565
@@ -561,7 +576,9 @@ def _handle_staging_operation(self, staging_allowed_local_path: Union[None, str,
561576
"You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands"
562577
)
563578

564-
abs_staging_allowed_local_paths = [os.path.abspath(i) for i in _staging_allowed_local_paths]
579+
abs_staging_allowed_local_paths = [
580+
os.path.abspath(i) for i in _staging_allowed_local_paths
581+
]
565582

566583
assert self.active_result_set is not None
567584
row = self.active_result_set.fetchone()
@@ -589,7 +606,9 @@ def _handle_staging_operation(self, staging_allowed_local_path: Union[None, str,
589606
)
590607

591608
# May be real headers, or could be json string
592-
headers = json.loads(row.headers) if isinstance(row.headers, str) else row.headers
609+
headers = (
610+
json.loads(row.headers) if isinstance(row.headers, str) else row.headers
611+
)
593612

594613
handler_args = {
595614
"presigned_url": row.presignedUrl,
@@ -616,7 +635,9 @@ def _handle_staging_operation(self, staging_allowed_local_path: Union[None, str,
616635
+ "Supported operations are GET, PUT, and REMOVE"
617636
)
618637

619-
def _handle_staging_put(self, presigned_url: str, local_file: str, headers: dict = None):
638+
def _handle_staging_put(
639+
self, presigned_url: str, local_file: str, headers: dict = None
640+
):
620641
"""Make an HTTP PUT request
621642
622643
Raise an exception if request fails. Returns no data.
@@ -639,15 +660,19 @@ def _handle_staging_put(self, presigned_url: str, local_file: str, headers: dict
639660
# fmt: on
640661

641662
if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
642-
raise Error(f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}")
663+
raise Error(
664+
f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}"
665+
)
643666

644667
if r.status_code == ACCEPTED:
645668
logger.debug(
646669
f"Response code {ACCEPTED} from server indicates ingestion command was accepted "
647670
+ "but not yet applied on the server. It's possible this command may fail later."
648671
)
649672

650-
def _handle_staging_get(self, local_file: str, presigned_url: str, headers: dict = None):
673+
def _handle_staging_get(
674+
self, local_file: str, presigned_url: str, headers: dict = None
675+
):
651676
"""Make an HTTP GET request, create a local file with the received data
652677
653678
Raise an exception if request fails. Returns no data.
@@ -661,7 +686,9 @@ def _handle_staging_get(self, local_file: str, presigned_url: str, headers: dict
661686
# response.ok verifies the status code is not between 400-600.
662687
# Any 2xx or 3xx will evaluate r.ok == True
663688
if not r.ok:
664-
raise Error(f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}")
689+
raise Error(
690+
f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}"
691+
)
665692

666693
with open(local_file, "wb") as fp:
667694
fp.write(r.content)
@@ -672,7 +699,9 @@ def _handle_staging_remove(self, presigned_url: str, headers: dict = None):
672699
r = requests.delete(url=presigned_url, headers=headers)
673700

674701
if not r.ok:
675-
raise Error(f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}")
702+
raise Error(
703+
f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}"
704+
)
676705

677706
def execute(
678707
self,
@@ -967,7 +996,8 @@ def cancel(self) -> None:
967996
self.thrift_backend.cancel_command(self.active_op_handle)
968997
else:
969998
logger.warning(
970-
"Attempting to cancel a command, but there is no " "currently executing command"
999+
"Attempting to cancel a command, but there is no "
1000+
"currently executing command"
9711001
)
9721002

9731003
def close(self) -> None:
@@ -1085,7 +1115,9 @@ def _convert_arrow_table(self, table):
10851115
ResultRow = Row(*column_names)
10861116

10871117
if self.connection.disable_pandas is True:
1088-
return [ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns())]
1118+
return [
1119+
ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns())
1120+
]
10891121

10901122
# Need to use nullable types, as otherwise type can change when there are missing values.
10911123
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types
@@ -1132,7 +1164,11 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:
11321164
n_remaining_rows = size - results.num_rows
11331165
self._next_row_index += results.num_rows
11341166

1135-
while n_remaining_rows > 0 and not self.has_been_closed_server_side and self.has_more_rows:
1167+
while (
1168+
n_remaining_rows > 0
1169+
and not self.has_been_closed_server_side
1170+
and self.has_more_rows
1171+
):
11361172
self._fill_results_buffer()
11371173
partial_results = self.results.next_n_rows(n_remaining_rows)
11381174
results = pyarrow.concat_tables([results, partial_results])

0 commit comments

Comments
 (0)