Skip to content

Commit

Permalink
dev(narugo): add resume for ranged headers of http_get function (#2823)
Browse files Browse the repository at this point in the history
* dev(narugo): add resume support for ranged downloading for function http_get

* dev(narugo): add unittest cases

* rework http_get tests

* style

* dev(narugo): unittest for adjust_range_header passed

* dev(narugo): add docstring for adjust_range_header function

* dev(narugo): migration completed

* additional refactoring

* fix typing

* dev(narugo): move the regex compilation out from the function

* dev(narugo): raise error on every unexpected branches

* Update file_download.py (should fix mypy)

---------

Co-authored-by: Celina Hanouti <[email protected]>
Co-authored-by: Lucain <[email protected]>
  • Loading branch information
3 people authored Feb 11, 2025
1 parent fc491d4 commit e5c84bc
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 18 deletions.
9 changes: 5 additions & 4 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
tqdm,
validate_hf_hub_args,
)
from .utils._http import _adjust_range_header
from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility
from .utils._typing import HTTP_METHOD_T
from .utils.sha import sha_fileobj
Expand Down Expand Up @@ -309,8 +310,8 @@ def http_get(
temp_file: BinaryIO,
*,
proxies: Optional[Dict] = None,
resume_size: float = 0,
headers: Optional[Dict[str, str]] = None,
resume_size: int = 0,
headers: Optional[Dict[str, Any]] = None,
expected_size: Optional[int] = None,
displayed_filename: Optional[str] = None,
_nb_retries: int = 5,
Expand All @@ -330,7 +331,7 @@ def http_get(
The file-like object where to save the file.
proxies (`dict`, *optional*):
Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
resume_size (`float`, *optional*):
resume_size (`int`, *optional*):
The number of bytes already downloaded. If set to 0 (default), the whole file is download. If set to a
positive number, the download will resume at the given position.
headers (`dict`, *optional*):
Expand Down Expand Up @@ -365,7 +366,7 @@ def http_get(
initial_headers = headers
headers = copy.deepcopy(headers) or {}
if resume_size > 0:
headers["Range"] = "bytes=%d-" % (resume_size,)
headers["Range"] = _adjust_range_header(headers.get("Range"), resume_size)

r = _request_wrapper(
method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT
Expand Down
41 changes: 41 additions & 0 deletions src/huggingface_hub/utils/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,3 +591,44 @@ def _curlify(request: requests.PreparedRequest) -> str:
flat_parts.append(quote(v))

return " ".join(flat_parts)


# Regex to parse HTTP Range header
RANGE_REGEX = re.compile(r"^\s*bytes\s*=\s*(\d*)\s*-\s*(\d*)\s*$", re.IGNORECASE)


def _adjust_range_header(original_range: Optional[str], resume_size: int) -> Optional[str]:
"""
Adjust HTTP Range header to account for resume position.
"""
if not original_range:
return f"bytes={resume_size}-"

if "," in original_range:
raise ValueError(f"Multiple ranges detected - {original_range!r}, not supported yet.")

match = RANGE_REGEX.match(original_range)
if not match:
raise RuntimeError(f"Invalid range format - {original_range!r}.")
start, end = match.groups()

if not start:
if not end:
raise RuntimeError(f"Invalid range format - {original_range!r}.")

new_suffix = int(end) - resume_size
new_range = f"bytes=-{new_suffix}"
if new_suffix <= 0:
raise RuntimeError(f"Empty new range - {new_range!r}.")
return new_range

start = int(start)
new_start = start + resume_size
if end:
end = int(end)
new_range = f"bytes={new_start}-{end}"
if new_start > end:
raise RuntimeError(f"Empty new range - {new_range!r}.")
return new_range

return f"bytes={new_start}-"
105 changes: 91 additions & 14 deletions tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable
from typing import Iterable, List
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -932,8 +932,8 @@ def test_get_pointer_path_but_invalid_relative_filename(self) -> None:
_get_pointer_path("path/to/storage", "abcdef", relative_filename)


class TestHttpGet(unittest.TestCase):
def test_http_get_with_ssl_and_timeout_error(self):
class TestHttpGet:
def test_http_get_with_ssl_and_timeout_error(self, caplog):
def _iter_content_1() -> Iterable[bytes]:
yield b"0" * 10
yield b"0" * 10
Expand Down Expand Up @@ -966,22 +966,99 @@ def _iter_content_4() -> Iterable[bytes]:

temp_file = io.BytesIO()

with self.assertLogs("huggingface_hub.file_download", level="WARNING") as records:
http_get("fake_url", temp_file=temp_file)
http_get("fake_url", temp_file=temp_file)

# Check 3 warnings
self.assertEqual(len(records.records), 3)
assert len([r for r in caplog.records if r.levelname == "WARNING"]) == 3

# Check final value
self.assertEqual(temp_file.tell(), 100)
self.assertEqual(temp_file.getvalue(), b"0" * 100)
assert temp_file.tell() == 100
assert temp_file.getvalue() == b"0" * 100

# Check number of calls + correct range headers
self.assertEqual(len(mock.call_args_list), 4)
self.assertEqual(mock.call_args_list[0].kwargs["headers"], {})
self.assertEqual(mock.call_args_list[1].kwargs["headers"], {"Range": "bytes=20-"})
self.assertEqual(mock.call_args_list[2].kwargs["headers"], {"Range": "bytes=30-"})
self.assertEqual(mock.call_args_list[3].kwargs["headers"], {"Range": "bytes=60-"})
assert len(mock.call_args_list) == 4
assert mock.call_args_list[0].kwargs["headers"] == {}
assert mock.call_args_list[1].kwargs["headers"] == {"Range": "bytes=20-"}
assert mock.call_args_list[2].kwargs["headers"] == {"Range": "bytes=30-"}
assert mock.call_args_list[3].kwargs["headers"] == {"Range": "bytes=60-"}

@pytest.mark.parametrize(
"initial_range,expected_ranges",
[
# Test suffix ranges (bytes=-100)
(
"bytes=-100",
[
"bytes=-100",
"bytes=-80",
"bytes=-70",
"bytes=-40",
],
),
# Test prefix ranges (bytes=15-)
(
"bytes=15-",
[
"bytes=15-",
"bytes=35-",
"bytes=45-",
"bytes=75-",
],
),
# Test double closed ranges (bytes=15-114)
(
"bytes=15-114",
[
"bytes=15-114",
"bytes=35-114",
"bytes=45-114",
"bytes=75-114",
],
),
],
)
def test_http_get_with_range_headers(self, caplog, initial_range: str, expected_ranges: List[str]):
def _iter_content_1() -> Iterable[bytes]:
yield b"0" * 10
yield b"0" * 10
raise requests.exceptions.SSLError("Fake SSLError")

def _iter_content_2() -> Iterable[bytes]:
yield b"0" * 10
raise requests.ReadTimeout("Fake ReadTimeout")

def _iter_content_3() -> Iterable[bytes]:
yield b"0" * 10
yield b"0" * 10
yield b"0" * 10
raise requests.ConnectionError("Fake ConnectionError")

def _iter_content_4() -> Iterable[bytes]:
yield b"0" * 10
yield b"0" * 10
yield b"0" * 10
yield b"0" * 10

with patch("huggingface_hub.file_download._request_wrapper") as mock:
mock.return_value.headers = {"Content-Length": 100}
mock.return_value.iter_content.side_effect = [
_iter_content_1(),
_iter_content_2(),
_iter_content_3(),
_iter_content_4(),
]

temp_file = io.BytesIO()

http_get("fake_url", temp_file=temp_file, headers={"Range": initial_range})

assert len([r for r in caplog.records if r.levelname == "WARNING"]) == 3

assert temp_file.tell() == 100
assert temp_file.getvalue() == b"0" * 100

assert len(mock.call_args_list) == 4
for i, expected_range in enumerate(expected_ranges):
assert mock.call_args_list[i].kwargs["headers"] == {"Range": expected_range}


class CreateSymlinkTest(unittest.TestCase):
Expand Down
25 changes: 25 additions & 0 deletions tests/test_utils_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from huggingface_hub.constants import ENDPOINT
from huggingface_hub.utils._http import (
OfflineModeIsEnabled,
_adjust_range_header,
configure_http_backend,
fix_hf_endpoint_in_url,
get_session,
Expand Down Expand Up @@ -311,3 +312,27 @@ def _is_uuid(string: str) -> bool:
)
def test_fix_hf_endpoint_in_url(base_url: str, endpoint: Optional[str], expected_url: str) -> None:
assert fix_hf_endpoint_in_url(base_url, endpoint) == expected_url


def test_adjust_range_header():
# Basic cases
assert _adjust_range_header(None, 10) == "bytes=10-"
assert _adjust_range_header("bytes=0-100", 10) == "bytes=10-100"
assert _adjust_range_header("bytes=-100", 10) == "bytes=-90"
assert _adjust_range_header("bytes=100-", 10) == "bytes=110-"

with pytest.raises(RuntimeError):
_adjust_range_header("invalid", 10)

with pytest.raises(RuntimeError):
_adjust_range_header("bytes=-", 10)

# Multiple ranges
with pytest.raises(ValueError):
_adjust_range_header("bytes=0-100,200-300", 10)

# Resume size exceeds range
with pytest.raises(RuntimeError):
_adjust_range_header("bytes=0-100", 150)
with pytest.raises(RuntimeError):
_adjust_range_header("bytes=-50", 100)

0 comments on commit e5c84bc

Please sign in to comment.