5
5
import re
6
6
from typing import Any , Dict , Tuple , List , Optional , Union , TYPE_CHECKING , Set
7
7
8
- from databricks .sql .backend .sea .models .base import ResultManifest
8
+ from databricks .sql .backend .sea .models .base import ExternalLink , ResultManifest
9
9
from databricks .sql .backend .sea .utils .constants import (
10
10
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP ,
11
11
ResultFormat ,
28
28
BackendType ,
29
29
ExecuteResponse ,
30
30
)
31
- from databricks .sql .exc import DatabaseError , ProgrammingError , ServerOperationError
31
+ from databricks .sql .exc import DatabaseError , ServerOperationError
32
32
from databricks .sql .backend .sea .utils .http_client import SeaHttpClient
33
33
from databricks .sql .types import SSLOptions
34
34
44
44
GetStatementResponse ,
45
45
CreateSessionResponse ,
46
46
)
47
+ from databricks .sql .backend .sea .models .responses import GetChunksResponse
47
48
48
49
logger = logging .getLogger (__name__ )
49
50
@@ -88,6 +89,7 @@ class SeaDatabricksClient(DatabricksClient):
88
89
STATEMENT_PATH = BASE_PATH + "statements"
89
90
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
90
91
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
92
+ CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
91
93
92
94
# SEA constants
93
95
POLL_INTERVAL_SECONDS = 0.2
@@ -123,18 +125,22 @@ def __init__(
123
125
)
124
126
125
127
self ._max_download_threads = kwargs .get ("max_download_threads" , 10 )
128
+ self ._ssl_options = ssl_options
129
+ self ._use_arrow_native_complex_types = kwargs .get (
130
+ "_use_arrow_native_complex_types" , True
131
+ )
126
132
127
133
# Extract warehouse ID from http_path
128
134
self .warehouse_id = self ._extract_warehouse_id (http_path )
129
135
130
136
# Initialize HTTP client
131
- self .http_client = SeaHttpClient (
137
+ self ._http_client = SeaHttpClient (
132
138
server_hostname = server_hostname ,
133
139
port = port ,
134
140
http_path = http_path ,
135
141
http_headers = http_headers ,
136
142
auth_provider = auth_provider ,
137
- ssl_options = ssl_options ,
143
+ ssl_options = self . _ssl_options ,
138
144
** kwargs ,
139
145
)
140
146
@@ -173,7 +179,7 @@ def _extract_warehouse_id(self, http_path: str) -> str:
173
179
f"Note: SEA only works for warehouses."
174
180
)
175
181
logger .error (error_message )
176
- raise ProgrammingError (error_message )
182
+ raise ValueError (error_message )
177
183
178
184
@property
179
185
def max_download_threads (self ) -> int :
@@ -220,7 +226,7 @@ def open_session(
220
226
schema = schema ,
221
227
)
222
228
223
- response = self .http_client ._make_request (
229
+ response = self ._http_client ._make_request (
224
230
method = "POST" , path = self .SESSION_PATH , data = request_data .to_dict ()
225
231
)
226
232
@@ -245,7 +251,7 @@ def close_session(self, session_id: SessionId) -> None:
245
251
session_id: The session identifier returned by open_session()
246
252
247
253
Raises:
248
- ProgrammingError : If the session ID is invalid
254
+ ValueError : If the session ID is invalid
249
255
OperationalError: If there's an error closing the session
250
256
"""
251
257
@@ -260,7 +266,7 @@ def close_session(self, session_id: SessionId) -> None:
260
266
session_id = sea_session_id ,
261
267
)
262
268
263
- self .http_client ._make_request (
269
+ self ._http_client ._make_request (
264
270
method = "DELETE" ,
265
271
path = self .SESSION_PATH_WITH_ID .format (sea_session_id ),
266
272
data = request_data .to_dict (),
@@ -342,7 +348,7 @@ def _results_message_to_execute_response(
342
348
343
349
# Check for compression
344
350
lz4_compressed = (
345
- response .manifest .result_compression == ResultCompression .LZ4_FRAME
351
+ response .manifest .result_compression == ResultCompression .LZ4_FRAME . value
346
352
)
347
353
348
354
execute_response = ExecuteResponse (
@@ -424,7 +430,7 @@ def execute_command(
424
430
enforce_embedded_schema_correctness: Whether to enforce schema correctness
425
431
426
432
Returns:
427
- ResultSet : A SeaResultSet instance for the executed command
433
+ SeaResultSet : A SeaResultSet instance for the executed command
428
434
"""
429
435
430
436
if session_id .backend_type != BackendType .SEA :
@@ -471,7 +477,7 @@ def execute_command(
471
477
result_compression = result_compression ,
472
478
)
473
479
474
- response_data = self .http_client ._make_request (
480
+ response_data = self ._http_client ._make_request (
475
481
method = "POST" , path = self .STATEMENT_PATH , data = request .to_dict ()
476
482
)
477
483
response = ExecuteStatementResponse .from_dict (response_data )
@@ -505,7 +511,7 @@ def cancel_command(self, command_id: CommandId) -> None:
505
511
command_id: Command identifier to cancel
506
512
507
513
Raises:
508
- ProgrammingError : If the command ID is invalid
514
+ ValueError : If the command ID is invalid
509
515
"""
510
516
511
517
if command_id .backend_type != BackendType .SEA :
@@ -516,7 +522,7 @@ def cancel_command(self, command_id: CommandId) -> None:
516
522
raise ValueError ("Not a valid SEA command ID" )
517
523
518
524
request = CancelStatementRequest (statement_id = sea_statement_id )
519
- self .http_client ._make_request (
525
+ self ._http_client ._make_request (
520
526
method = "POST" ,
521
527
path = self .CANCEL_STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
522
528
data = request .to_dict (),
@@ -530,7 +536,7 @@ def close_command(self, command_id: CommandId) -> None:
530
536
command_id: Command identifier to close
531
537
532
538
Raises:
533
- ProgrammingError : If the command ID is invalid
539
+ ValueError : If the command ID is invalid
534
540
"""
535
541
536
542
if command_id .backend_type != BackendType .SEA :
@@ -541,7 +547,7 @@ def close_command(self, command_id: CommandId) -> None:
541
547
raise ValueError ("Not a valid SEA command ID" )
542
548
543
549
request = CloseStatementRequest (statement_id = sea_statement_id )
544
- self .http_client ._make_request (
550
+ self ._http_client ._make_request (
545
551
method = "DELETE" ,
546
552
path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
547
553
data = request .to_dict (),
@@ -558,7 +564,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
558
564
CommandState: The current state of the command
559
565
560
566
Raises:
561
- ProgrammingError : If the command ID is invalid
567
+ ValueError : If the command ID is invalid
562
568
"""
563
569
564
570
if command_id .backend_type != BackendType .SEA :
@@ -569,7 +575,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
569
575
raise ValueError ("Not a valid SEA command ID" )
570
576
571
577
request = GetStatementRequest (statement_id = sea_statement_id )
572
- response_data = self .http_client ._make_request (
578
+ response_data = self ._http_client ._make_request (
573
579
method = "GET" ,
574
580
path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
575
581
data = request .to_dict (),
@@ -609,7 +615,7 @@ def get_execution_result(
609
615
request = GetStatementRequest (statement_id = sea_statement_id )
610
616
611
617
# Get the statement result
612
- response_data = self .http_client ._make_request (
618
+ response_data = self ._http_client ._make_request (
613
619
method = "GET" ,
614
620
path = self .STATEMENT_PATH_WITH_ID .format (sea_statement_id ),
615
621
data = request .to_dict (),
@@ -631,6 +637,35 @@ def get_execution_result(
631
637
arraysize = cursor .arraysize ,
632
638
)
633
639
640
def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
    """
    Get the external link for a single result chunk of a statement.

    Issues a GET request against ``CHUNK_PATH_WITH_ID_AND_INDEX`` (formatted
    with ``statement_id`` and ``chunk_index``), parses the payload with
    ``GetChunksResponse.from_dict`` and returns the first external link whose
    ``chunk_index`` equals the requested one.

    Args:
        statement_id: The statement ID
        chunk_index: The index of the chunk whose link is requested

    Returns:
        ExternalLink: External link for the requested chunk

    Raises:
        ServerOperationError: If the response carries no link for
            ``chunk_index`` (including a missing or empty link list)
    """

    response_data = self._http_client._make_request(
        method="GET",
        path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
    )
    response = GetChunksResponse.from_dict(response_data)

    # A missing ``external_links`` field is handled like an empty list.
    links = response.external_links or []

    # The response may hold several links; keep only the exact index match.
    link = next(
        (candidate for candidate in links if candidate.chunk_index == chunk_index),
        None,
    )

    # Compare against the sentinel explicitly so the "not found" branch
    # depends only on the search result, never on the link's truthiness.
    if link is None:
        raise ServerOperationError(
            f"No link found for chunk index {chunk_index}",
            {
                "operation-id": statement_id,
                "diagnostic-info": None,
            },
        )

    return link
668
+
634
669
# == Metadata Operations ==
635
670
636
671
def get_catalogs (
0 commit comments