Chunk download latency #634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 21 commits into base: sea-migration
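Summary of the diff: the change threads telemetry identifiers (session_id_hex, statement_id, and a per-file chunk_id) from the Thrift backend through ResultFileDownloadManager into each ResultSetDownloadHandler, and wraps the download in @log_latency so per-chunk download latency can be recorded. The SEA queue passes placeholder values until telemetry is implemented there.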
src/databricks/sql/backend/sea/queue.py (5 additions, 0 deletions)

```diff
@@ -4,6 +4,7 @@
 from typing import List, Optional, Tuple, Union, TYPE_CHECKING
 
 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
+from databricks.sql.telemetry.models.enums import StatementType
 
 try:
     import pyarrow
@@ -134,9 +135,13 @@ def __init__(
         super().__init__(
             max_download_threads=max_download_threads,
             ssl_options=ssl_options,
+            statement_id=statement_id,
             schema_bytes=None,
             lz4_compressed=lz4_compressed,
             description=description,
+            # TODO: fix these arguments when telemetry is implemented in SEA
+            session_id_hex=None,
+            chunk_id=0,
         )
 
         self._sea_client = sea_client
```
src/databricks/sql/backend/thrift_backend.py (17 additions, 2 deletions)

```diff
@@ -6,9 +6,10 @@
 import time
 import threading
 from typing import List, Optional, Union, Any, TYPE_CHECKING
+from uuid import UUID
 
 from databricks.sql.result_set import ThriftResultSet
-
+from databricks.sql.telemetry.models.event import StatementType
 
 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
@@ -900,6 +901,7 @@ def get_execution_result(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -1037,6 +1039,7 @@ def execute_command(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def get_catalogs(
@@ -1077,6 +1080,7 @@ def get_catalogs(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def get_schemas(
@@ -1123,6 +1127,7 @@ def get_schemas(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def get_tables(
@@ -1173,6 +1178,7 @@ def get_tables(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def get_columns(
@@ -1223,6 +1229,7 @@ def get_columns(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            session_id_hex=self._session_id_hex,
         )
 
     def _handle_execute_response(self, resp, cursor):
@@ -1257,6 +1264,7 @@ def fetch_results(
         lz4_compressed: bool,
         arrow_schema_bytes,
         description,
+        chunk_id: int,
         use_cloud_fetch=True,
     ):
         thrift_handle = command_id.to_thrift_handle()
@@ -1294,9 +1302,16 @@ def fetch_results(
             lz4_compressed=lz4_compressed,
             description=description,
             ssl_options=self._ssl_options,
+            session_id_hex=self._session_id_hex,
+            statement_id=command_id.to_hex_guid(),
+            chunk_id=chunk_id,
         )
 
-        return queue, resp.hasMoreRows
+        return (
+            queue,
+            resp.hasMoreRows,
+            len(resp.results.resultLinks) if resp.results.resultLinks else 0,
+        )
 
     def cancel_command(self, command_id: CommandId) -> None:
         thrift_handle = command_id.to_thrift_handle()
```
src/databricks/sql/backend/types.py (1 addition, 0 deletions)

```diff
@@ -4,6 +4,7 @@
 import logging
 
 from databricks.sql.backend.utils.guid_utils import guid_to_hex_id
+from databricks.sql.telemetry.models.enums import StatementType
 from databricks.sql.thrift_api.TCLIService import ttypes
 
 logger = logging.getLogger(__name__)
```
src/databricks/sql/client.py (3 additions, 1 deletion)

```diff
@@ -280,7 +280,9 @@ def read(self) -> Optional[OAuthToken]:
 
         driver_connection_params = DriverConnectionParameters(
             http_path=http_path,
-            mode=DatabricksClientType.THRIFT,
+            mode=DatabricksClientType.SEA
+            if self.session.use_sea
+            else DatabricksClientType.THRIFT,
             host_info=HostDetails(host_url=server_hostname, port=self.session.port),
             auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider),
             auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider),
```
src/databricks/sql/cloudfetch/download_manager.py (23 additions, 10 deletions)

```diff
@@ -1,15 +1,15 @@
 import logging
 
 from concurrent.futures import ThreadPoolExecutor, Future
-from typing import List, Union
+from typing import List, Union, Tuple, Optional
 
 from databricks.sql.cloudfetch.downloader import (
     ResultSetDownloadHandler,
     DownloadableResultSettings,
     DownloadedFile,
 )
 from databricks.sql.types import SSLOptions
-
+from databricks.sql.telemetry.models.event import StatementType
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 
 logger = logging.getLogger(__name__)
@@ -22,24 +22,31 @@ def __init__(
         max_download_threads: int,
         lz4_compressed: bool,
         ssl_options: SSLOptions,
+        session_id_hex: Optional[str],
+        statement_id: str,
+        chunk_id: int,
     ):
-        self._pending_links: List[TSparkArrowResultLink] = []
-        for link in links:
+        self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = []
+        self.chunk_id = chunk_id
+        for i, link in enumerate(links, start=chunk_id):
             if link.rowCount <= 0:
                 continue
             logger.debug(
-                "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
-                    link.startRowOffset, link.rowCount
+                "ResultFileDownloadManager: adding file link, chunk id {}, start offset {}, row count: {}".format(
+                    i, link.startRowOffset, link.rowCount
                 )
             )
-            self._pending_links.append(link)
+            self._pending_links.append((i, link))
+        self.chunk_id += len(links)
 
         self._download_tasks: List[Future[DownloadedFile]] = []
         self._max_download_threads: int = max_download_threads
         self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
 
         self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
         self._ssl_options = ssl_options
+        self.session_id_hex = session_id_hex
+        self.statement_id = statement_id
 
     def get_next_downloaded_file(
         self, next_row_offset: int
@@ -89,14 +96,19 @@ def _schedule_downloads(self):
         while (len(self._download_tasks) < self._max_download_threads) and (
             len(self._pending_links) > 0
         ):
-            link = self._pending_links.pop(0)
+            chunk_id, link = self._pending_links.pop(0)
             logger.debug(
-                "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
+                "- chunk: {}, start: {}, row count: {}".format(
+                    chunk_id, link.startRowOffset, link.rowCount
+                )
             )
             handler = ResultSetDownloadHandler(
                 settings=self._downloadable_result_settings,
                 link=link,
                 ssl_options=self._ssl_options,
+                chunk_id=chunk_id,
+                session_id_hex=self.session_id_hex,
+                statement_id=self.statement_id,
             )
             task = self._thread_pool.submit(handler.run)
             self._download_tasks.append(task)
@@ -117,7 +129,8 @@ def add_link(self, link: TSparkArrowResultLink):
                 link.startRowOffset, link.rowCount
             )
         )
-        self._pending_links.append(link)
+        self._pending_links.append((self.chunk_id, link))
+        self.chunk_id += 1
 
     def _shutdown_manager(self):
         # Clear download handlers and shutdown the thread pool
```
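Note the numbering subtlety above: enumerate(links, start=chunk_id) gives every link an id, including zero-row links that are never queued, and self.chunk_id then advances by len(links), so ids stay contiguous across batches even when some links are skipped. A self-contained sketch of that behavior (FakeLink is a hypothetical stand-in for TSparkArrowResultLink):

```python
from dataclasses import dataclass

@dataclass
class FakeLink:
    # Hypothetical stand-in for TSparkArrowResultLink; only the two
    # fields the manager inspects are modeled.
    startRowOffset: int
    rowCount: int

def assign_chunk_ids(links, start):
    # Mirrors the manager's __init__: every link consumes an id, but
    # zero-row links are never queued for download.
    pending = [
        (i, link)
        for i, link in enumerate(links, start=start)
        if link.rowCount > 0
    ]
    next_chunk_id = start + len(links)
    return pending, next_chunk_id

links = [FakeLink(0, 100), FakeLink(100, 0), FakeLink(100, 50)]
pending, next_id = assign_chunk_ids(links, start=0)
print([i for i, _ in pending])  # [0, 2] -- id 1 went to the empty link
print(next_id)                  # 3 -- numbering continues past skipped links
```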
src/databricks/sql/cloudfetch/downloader.py (12 additions, 2 deletions)

```diff
@@ -1,5 +1,6 @@
 import logging
 from dataclasses import dataclass
+from typing import Optional
 
 import requests
 from requests.adapters import HTTPAdapter, Retry
@@ -9,6 +10,8 @@
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
+from databricks.sql.telemetry.latency_logger import log_latency
+from databricks.sql.telemetry.models.event import StatementType
 
 logger = logging.getLogger(__name__)
 
@@ -66,11 +69,18 @@ def __init__(
         settings: DownloadableResultSettings,
         link: TSparkArrowResultLink,
         ssl_options: SSLOptions,
+        chunk_id: int,
+        session_id_hex: Optional[str],
+        statement_id: str,
     ):
         self.settings = settings
         self.link = link
         self._ssl_options = ssl_options
+        self.chunk_id = chunk_id
+        self.session_id_hex = session_id_hex
+        self.statement_id = statement_id
 
+    @log_latency(StatementType.QUERY)
     def run(self) -> DownloadedFile:
         """
         Download the file described in the cloud fetch link.
@@ -80,8 +90,8 @@ def run(self) -> DownloadedFile:
         """
 
         logger.debug(
-            "ResultSetDownloadHandler: starting file download, offset {}, row count {}".format(
-                self.link.startRowOffset, self.link.rowCount
+            "ResultSetDownloadHandler: starting file download, chunk id {}, offset {}, row count {}".format(
+                self.chunk_id, self.link.startRowOffset, self.link.rowCount
             )
         )
```
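The @log_latency(StatementType.QUERY) decorator is imported from databricks.sql.telemetry.latency_logger, which this diff does not show. A minimal stand-in with the same call shape, assuming it reads the telemetry attributes this PR stores on the handler (self.chunk_id, self.session_id_hex, self.statement_id):

```python
import functools
import logging
import time

logger = logging.getLogger(__name__)

def log_latency(statement_type):
    # Hypothetical sketch of the decorator's shape; the real implementation
    # emits telemetry events rather than debug logs.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            start = time.perf_counter()
            try:
                return func(self, *args, **kwargs)
            finally:
                elapsed_ms = (time.perf_counter() - start) * 1000
                logger.debug(
                    "latency %.1f ms type=%s statement=%s chunk=%s",
                    elapsed_ms,
                    statement_type,
                    getattr(self, "statement_id", None),
                    getattr(self, "chunk_id", None),
                )
        return wrapper
    return decorator
```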
src/databricks/sql/result_set.py (11 additions, 1 deletion)

```diff
@@ -22,6 +22,7 @@
     ColumnQueue,
 )
 from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
+from databricks.sql.telemetry.models.event import StatementType
 
 logger = logging.getLogger(__name__)
 
@@ -192,6 +193,7 @@ def __init__(
         connection: "Connection",
         execute_response: "ExecuteResponse",
         thrift_client: "ThriftDatabricksClient",
+        session_id_hex: Optional[str],
        buffer_size_bytes: int = 104857600,
         arraysize: int = 10000,
         use_cloud_fetch: bool = True,
@@ -215,6 +217,7 @@ def __init__(
         :param ssl_options: SSL options for cloud fetch
         :param is_direct_results: Whether there are more rows to fetch
         """
+        self.num_downloaded_chunks = 0
 
         # Initialize ThriftResultSet-specific attributes
         self._use_cloud_fetch = use_cloud_fetch
@@ -234,7 +237,12 @@ def __init__(
                 lz4_compressed=execute_response.lz4_compressed,
                 description=execute_response.description,
                 ssl_options=ssl_options,
+                session_id_hex=session_id_hex,
+                statement_id=execute_response.command_id.to_hex_guid(),
+                chunk_id=self.num_downloaded_chunks,
             )
+            if t_row_set and t_row_set.resultLinks:
+                self.num_downloaded_chunks += len(t_row_set.resultLinks)
 
         # Call parent constructor with common attributes
         super().__init__(
@@ -258,7 +266,7 @@ def __init__(
             self._fill_results_buffer()
 
     def _fill_results_buffer(self):
-        results, is_direct_results = self.backend.fetch_results(
+        results, is_direct_results, result_links_count = self.backend.fetch_results(
             command_id=self.command_id,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
@@ -267,9 +275,11 @@ def _fill_results_buffer(self):
             arrow_schema_bytes=self._arrow_schema_bytes,
             description=self.description,
             use_cloud_fetch=self._use_cloud_fetch,
+            chunk_id=self.num_downloaded_chunks,
         )
         self.results = results
         self.is_direct_results = is_direct_results
+        self.num_downloaded_chunks += result_links_count
 
     def _convert_columnar_table(self, table):
         column_names = [c[0] for c in self.description]
```
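The invariant this maintains: num_downloaded_chunks always holds the id the next chunk should receive. It is seeded from the initial TRowSet's result links and advanced by the count each fetch_results call now returns. A toy trace (counts are illustrative):

```python
# Illustrative trace of the chunk counter across a result set's lifetime.
num_downloaded_chunks = 0

num_downloaded_chunks += 3            # initial t_row_set carried 3 result links
for result_links_count in (2, 0, 4):  # counts returned by later fetch_results
    # each fetch passes chunk_id=num_downloaded_chunks before advancing
    num_downloaded_chunks += result_links_count

print(num_downloaded_chunks)  # 9 -- chunk ids 0..8 have been handed out
```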
src/databricks/sql/session.py (2 additions, 2 deletions)

```diff
@@ -97,10 +97,10 @@ def _create_backend(
         kwargs: dict,
     ) -> DatabricksClient:
         """Create and return the appropriate backend client."""
-        use_sea = kwargs.get("use_sea", False)
+        self.use_sea = kwargs.get("use_sea", False)
 
         databricks_client_class: Type[DatabricksClient]
-        if use_sea:
+        if self.use_sea:
             logger.debug("Creating SEA backend client")
             databricks_client_class = SeaDatabricksClient
         else:
```
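Persisting the flag as self.use_sea (rather than a local) is what lets client.py, above, pick the telemetry mode after the session is built. A hedged, self-contained mirror of the selection logic (the real DatabricksClientType enum lives in the telemetry models):

```python
from enum import Enum

class DatabricksClientType(Enum):
    # Stand-in for the enum in databricks.sql.telemetry; illustration only.
    THRIFT = "THRIFT"
    SEA = "SEA"

def connection_mode(use_sea: bool) -> DatabricksClientType:
    # Same conditional expression client.py now uses with session.use_sea.
    return DatabricksClientType.SEA if use_sea else DatabricksClientType.THRIFT

assert connection_mode(True) is DatabricksClientType.SEA
assert connection_mode(False) is DatabricksClientType.THRIFT
```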