SEA: re-fetch links in case of expiry #635

Draft · wants to merge 3 commits into base: sea-decouple-link-fetch
12 changes: 6 additions & 6 deletions examples/experimental/tests/test_sea_sync_query.py
@@ -4,9 +4,10 @@
 import os
 import sys
 import logging
+import time
 from databricks.sql.client import Connection
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
@@ -51,20 +52,19 @@ def test_sea_sync_query_with_cloud_fetch():
     )
 
     # Execute a query that generates large rows to force multiple chunks
-    requested_row_count = 10000
+    requested_row_count = 100000000
     cursor = connection.cursor()
     query = f"""
-    SELECT
-        id,
-        concat('value_', repeat('a', 10000)) as test_value
-    FROM range(1, {requested_row_count} + 1) AS t(id)
+    SELECT * FROM samples.tpch.lineitem LIMIT {requested_row_count}
     """
 
     logger.info(
         f"Executing synchronous query with cloud fetch to generate {requested_row_count} rows"
     )
     cursor.execute(query)
     results = [cursor.fetchone()]
+    logger.info("SLEEPING FOR 1000 SECONDS TO EXPIRE LINKS")
+    time.sleep(1000)
     results.extend(cursor.fetchmany(10))
     results.extend(cursor.fetchall())
     actual_row_count = len(results)
24 changes: 24 additions & 0 deletions src/databricks/sql/backend/sea/queue.py
@@ -204,7 +204,24 @@ def _worker_loop(self):
             if not links_downloaded:
                 break
 
+    def _restart_from_expired_link(self, link: TSparkArrowResultLink):
+        self.stop()
+
+        with self._link_data_update:
+            self.download_manager.cancel_tasks_from_offset(link.startRowOffset)
+
+            chunks_to_restart = []
+            for chunk_index, l in self.chunk_index_to_link.items():
+                if l.row_offset < link.startRowOffset:
+                    continue
+                chunks_to_restart.append(chunk_index)
+            for chunk_index in chunks_to_restart:
+                self.chunk_index_to_link.pop(chunk_index)
+
+        self.start()
+
     def start(self):
+        self._shutdown_event.clear()
         self._worker_thread = threading.Thread(target=self._worker_loop)
         self._worker_thread.start()
 
@@ -269,6 +286,7 @@ def __init__(
             max_download_threads=max_download_threads,
             lz4_compressed=lz4_compressed,
             ssl_options=ssl_options,
+            expiry_callback=self._expiry_callback,
         )
 
         self.link_fetcher = LinkFetcher(
@@ -283,6 +301,12 @@ def __init__(
         # Initialize table and position
         self.table = self._create_next_table()
 
+    def _expiry_callback(self, link: TSparkArrowResultLink):
+        logger.info(
+            f"SeaCloudFetchQueue: Link expired, restarting from offset {link.startRowOffset}"
+        )
+        self.link_fetcher._restart_from_expired_link(link)
+
     def _create_next_table(self) -> Union["pyarrow.Table", None]:
         """Create next table by retrieving the logical next downloaded file."""
         if not self.download_manager:
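
The recovery path above crosses three components: the downloader detects expiry, the queue's callback delegates, and the LinkFetcher stops its worker, prunes stale state, and restarts. Below is a minimal runnable sketch of that handshake; FakeDownloadManager and FakeLinkFetcher are illustrative stand-ins, not part of this PR.

# Illustrative stand-ins only; they mirror the expiry -> stop -> cancel -> prune -> restart sequence.
from threading import Event, Lock


class FakeDownloadManager:
    def cancel_tasks_from_offset(self, offset: int) -> None:
        print(f"cancelling tasks at row offset >= {offset}")


class FakeLinkFetcher:
    def __init__(self, download_manager: FakeDownloadManager):
        self.download_manager = download_manager
        self.chunk_index_to_link = {0: 0, 1: 100, 2: 200}  # chunk index -> row offset
        self._link_data_update = Lock()
        self._shutdown_event = Event()

    def stop(self) -> None:
        self._shutdown_event.set()  # the real worker loop observes this and exits

    def start(self) -> None:
        self._shutdown_event.clear()  # the real code spawns a fresh worker thread here

    def restart_from_expired_link(self, expired_offset: int) -> None:
        self.stop()
        with self._link_data_update:
            self.download_manager.cancel_tasks_from_offset(expired_offset)
            # Forget cached links from the expired offset onwards so the worker
            # re-fetches them from the server with fresh, unexpired URLs.
            self.chunk_index_to_link = {
                idx: off
                for idx, off in self.chunk_index_to_link.items()
                if off < expired_offset
            }
        self.start()


fetcher = FakeLinkFetcher(FakeDownloadManager())
fetcher.restart_from_expired_link(100)
print(fetcher.chunk_index_to_link)  # {0: 0}; chunks 1 and 2 will be re-fetched

The lock is the important detail: pruning chunk_index_to_link must not race with a worker loop that is mid-iteration while being stopped and restarted.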
74 changes: 66 additions & 8 deletions src/databricks/sql/cloudfetch/download_manager.py
@@ -1,7 +1,7 @@
 import logging
 
 from concurrent.futures import ThreadPoolExecutor, Future
-from typing import List, Union
+from typing import Callable, List, Optional, Union, Generic, TypeVar
 
 from databricks.sql.cloudfetch.downloader import (
     ResultSetDownloadHandler,
@@ -14,6 +14,28 @@
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar('T')
+
+
+class TaskWithMetadata(Generic[T]):
+    """
+    Wrapper around Future that stores additional metadata (the link).
+    Provides type-safe access to both the Future result and the associated link.
+    """
+
+    def __init__(self, future: Future[T], link: TSparkArrowResultLink):
+        self.future = future
+        self.link = link
+
+    def result(self, timeout: Optional[float] = None) -> T:
+        """Get the result of the Future, blocking if necessary."""
+        return self.future.result(timeout)
+
+    def cancel(self) -> bool:
+        """Cancel the Future if possible."""
+        return self.future.cancel()
+
+
 class ResultFileDownloadManager:
     def __init__(
@@ -22,6 +44,7 @@ def __init__(
         max_download_threads: int,
         lz4_compressed: bool,
         ssl_options: SSLOptions,
+        expiry_callback: Optional[Callable[[TSparkArrowResultLink], None]] = None,
     ):
         self._pending_links: List[TSparkArrowResultLink] = []
         for link in links:
@@ -34,12 +57,13 @@ def __init__(
             )
             self._pending_links.append(link)
 
-        self._download_tasks: List[Future[DownloadedFile]] = []
+        self._download_tasks: List[TaskWithMetadata[DownloadedFile]] = []
         self._max_download_threads: int = max_download_threads
         self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
 
         self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
         self._ssl_options = ssl_options
+        self._expiry_callback = expiry_callback
 
     def get_next_downloaded_file(
         self, next_row_offset: int
@@ -51,7 +75,7 @@ def get_next_downloaded_file(
         in relation to the full result. File downloads are scheduled if not already, and once the correct
         download handler is located, the function waits for the download status and returns the resulting file.
         If there are no more downloads, a download was not successful, or the correct file could not be located,
-        this function shuts down the thread pool and returns None.
+        this function returns None.
 
         Args:
             next_row_offset (int): The offset of the starting row of the next file we want data from.
@@ -62,7 +86,6 @@ def get_next_downloaded_file(
 
         # No more files to download from this batch of links
        if len(self._download_tasks) == 0:
-            self._shutdown_manager()
             return None
 
         task = self._download_tasks.pop(0)
@@ -81,6 +104,41 @@ def get_next_downloaded_file(
 
         return file
 
+    def cancel_tasks_from_offset(self, start_row_offset: int):
+        """
+        Cancel all download tasks whose link starts at or after a specific row offset.
+        This is used when links expire and we need to restart from a certain point.
+
+        Args:
+            start_row_offset (int): Row offset from which to cancel tasks
+        """
+
+        def to_cancel(link: TSparkArrowResultLink) -> bool:
+            return link.startRowOffset >= start_row_offset
+
+        tasks_to_cancel = [
+            task for task in self._download_tasks if to_cancel(task.link)
+        ]
+        for task in tasks_to_cancel:
+            task.cancel()
+        logger.info(
+            f"ResultFileDownloadManager: cancelled {len(tasks_to_cancel)} tasks from offset {start_row_offset}"
+        )
+
+        # Remove cancelled tasks from the download queue
+        self._download_tasks = [
+            task for task in self._download_tasks if not to_cancel(task.link)
+        ]
+
+        # Drop pending links from the offset onwards; they will be re-added with fresh URLs
+        pending_links_to_keep = [
+            link for link in self._pending_links if not to_cancel(link)
+        ]
+        removed_link_count = len(self._pending_links) - len(pending_links_to_keep)
+        self._pending_links = pending_links_to_keep
+        logger.info(
+            f"ResultFileDownloadManager: removed {removed_link_count} links from pending links"
+        )
+
     def _schedule_downloads(self):
         """
         While download queue has a capacity, peek pending links and submit them to thread pool.
@@ -97,21 +155,21 @@ def _schedule_downloads(self):
                 settings=self._downloadable_result_settings,
                 link=link,
                 ssl_options=self._ssl_options,
+                expiry_callback=self._expiry_callback,
             )
-            task = self._thread_pool.submit(handler.run)
+            future = self._thread_pool.submit(handler.run)
+            task = TaskWithMetadata(future, link)
             self._download_tasks.append(task)
 
     def add_link(self, link: TSparkArrowResultLink):
         """
         Add more links to the download manager.
 
         Args:
-            link: Link to add
+            link (TSparkArrowResultLink): The link to add to the download manager.
         """
 
         if link.rowCount <= 0:
             return
 
         logger.debug(
             "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
                 link.startRowOffset, link.rowCount
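
Why wrap the Future at all? cancel_tasks_from_offset has to decide, per task, whether it falls before or after the expired link, and a bare Future carries no link. A self-contained sketch of the same pattern; the Link and Task names here are stand-ins, not this PR's types.

from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
import time


@dataclass
class Link:  # stand-in for TSparkArrowResultLink
    startRowOffset: int


class Task:
    """Pairs a Future with the link it downloads (cf. TaskWithMetadata)."""

    def __init__(self, future: Future, link: Link):
        self.future = future
        self.link = link

    def cancel(self) -> bool:
        return self.future.cancel()


pool = ThreadPoolExecutor(max_workers=1)
tasks = [Task(pool.submit(time.sleep, 0.1), Link(off)) for off in (0, 100, 200)]

# On expiry at offset 100: cancel tasks at or after the cutoff, keep earlier ones,
# since rows before the cutoff are still served by valid links.
cutoff = 100
for t in tasks:
    if t.link.startRowOffset >= cutoff:
        t.cancel()
tasks = [t for t in tasks if t.link.startRowOffset < cutoff]
print([t.link.startRowOffset for t in tasks])  # [0]
pool.shutdown(wait=False, cancel_futures=True)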
12 changes: 6 additions & 6 deletions src/databricks/sql/cloudfetch/downloader.py
@@ -1,5 +1,6 @@
 import logging
 from dataclasses import dataclass
+from typing import Callable
 
 import requests
 from requests.adapters import HTTPAdapter, Retry
@@ -66,10 +67,12 @@ def __init__(
         settings: DownloadableResultSettings,
         link: TSparkArrowResultLink,
         ssl_options: SSLOptions,
+        expiry_callback: Callable[[TSparkArrowResultLink], None],
     ):
         self.settings = settings
         self.link = link
         self._ssl_options = ssl_options
+        self._expiry_callback = expiry_callback
 
     def run(self) -> DownloadedFile:
         """
@@ -86,9 +89,7 @@ def run(self) -> DownloadedFile:
         )
 
         # Check if link is already expired or is expiring
-        ResultSetDownloadHandler._validate_link(
-            self.link, self.settings.link_expiry_buffer_secs
-        )
+        self._validate_link(self.link, self.settings.link_expiry_buffer_secs)
 
         session = requests.Session()
         session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
@@ -136,8 +137,7 @@ def run(self) -> DownloadedFile:
             if session:
                 session.close()
 
-    @staticmethod
-    def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
+    def _validate_link(self, link: TSparkArrowResultLink, expiry_buffer_secs: int):
         """
         Check if a link has expired or will expire.
 
@@ -149,7 +149,7 @@ def _validate_link(self, link: TSparkArrowResultLink, expiry_buffer_secs: int):
             link.expiryTime <= current_time
             or link.expiryTime - current_time <= expiry_buffer_secs
         ):
-            raise Error("CloudFetch link has expired")
+            self._expiry_callback(link)
 
     @staticmethod
     def _decompress_data(compressed_data: bytes) -> bytes:
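
Note that the expiry test itself is unchanged; only the failure action is. A link is treated as bad both when it is already past expiryTime and when it will expire within link_expiry_buffer_secs. A standalone restatement of that arithmetic, with an illustrative function name:

import time


def link_is_expired_or_expiring(expiry_time: float, expiry_buffer_secs: float) -> bool:
    """True if the link is past its expiry, or will expire within the buffer."""
    current_time = time.time()
    return (
        expiry_time <= current_time
        or expiry_time - current_time <= expiry_buffer_secs
    )


# A link expiring 5 seconds from now, checked with a 10-second buffer, is
# already considered expired, so the handler invokes the expiry callback.
print(link_is_expired_or_expiring(time.time() + 5, 10))  # True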
6 changes: 6 additions & 0 deletions src/databricks/sql/utils.py
@@ -14,6 +14,8 @@
 
 import lz4.frame
 
+from databricks.sql.exc import Error
+
 try:
     import pyarrow
 except ImportError:
@@ -374,12 +376,16 @@ def __init__(
             )
         )
 
+        def expiry_callback(link: TSparkArrowResultLink):
+            raise Error("CloudFetch link has expired")
+
         # Initialize download manager
         self.download_manager = ResultFileDownloadManager(
             links=self.result_links,
             max_download_threads=self.max_download_threads,
             lz4_compressed=self.lz4_compressed,
             ssl_options=self._ssl_options,
+            expiry_callback=expiry_callback,
         )
 
         # Initialize table and position
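
The net effect of this hunk is that the Thrift cloud-fetch path keeps its original fail-fast semantics while the SEA path recovers; both are just different callbacks handed to the same download manager. A minimal sketch of the two strategies, with a local Error stand-in and hypothetical wiring:

class Error(Exception):  # stand-in for databricks.sql.exc.Error
    pass


def thrift_expiry_callback(link) -> None:
    # Thrift path: preserve the pre-PR behavior and fail the fetch immediately.
    raise Error("CloudFetch link has expired")


def sea_expiry_callback(link) -> None:
    # SEA path: recover by restarting the link fetcher from the expired offset.
    print(f"re-fetching links from offset {link.startRowOffset}")


class _Link:  # minimal stand-in for TSparkArrowResultLink
    startRowOffset = 0


sea_expiry_callback(_Link())  # prints instead of raising

The callback is a strategy hook: the downloader only detects expiry, and the owner of the queue decides the policy.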
4 changes: 4 additions & 0 deletions tests/unit/test_download_manager.py
@@ -14,11 +14,15 @@ class DownloadManagerTests(unittest.TestCase):
     def create_download_manager(
         self, links, max_download_threads=10, lz4_compressed=True
     ):
+        def expiry_callback(link: TSparkArrowResultLink):
+            return None
+
         return download_manager.ResultFileDownloadManager(
             links,
             max_download_threads,
             lz4_compressed,
             ssl_options=SSLOptions(),
+            expiry_callback=expiry_callback,
         )
 
     def create_result_link(
Expand Down