diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index f96d663d80..e9f3d9fba7 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -60,6 +60,7 @@ tqdm, validate_hf_hub_args, ) +from .utils._http import _adjust_range_header from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility from .utils._typing import HTTP_METHOD_T from .utils.sha import sha_fileobj @@ -309,8 +310,8 @@ def http_get( temp_file: BinaryIO, *, proxies: Optional[Dict] = None, - resume_size: float = 0, - headers: Optional[Dict[str, str]] = None, + resume_size: int = 0, + headers: Optional[Dict[str, Any]] = None, expected_size: Optional[int] = None, displayed_filename: Optional[str] = None, _nb_retries: int = 5, @@ -330,7 +331,7 @@ def http_get( The file-like object where to save the file. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. - resume_size (`float`, *optional*): + resume_size (`int`, *optional*): The number of bytes already downloaded. If set to 0 (default), the whole file is download. If set to a positive number, the download will resume at the given position. headers (`dict`, *optional*): @@ -365,7 +366,7 @@ def http_get( initial_headers = headers headers = copy.deepcopy(headers) or {} if resume_size > 0: - headers["Range"] = "bytes=%d-" % (resume_size,) + headers["Range"] = _adjust_range_header(headers.get("Range"), resume_size) r = _request_wrapper( method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT diff --git a/src/huggingface_hub/utils/_http.py b/src/huggingface_hub/utils/_http.py index 5f3f590441..0ed0789dc3 100644 --- a/src/huggingface_hub/utils/_http.py +++ b/src/huggingface_hub/utils/_http.py @@ -591,3 +591,44 @@ def _curlify(request: requests.PreparedRequest) -> str: flat_parts.append(quote(v)) return " ".join(flat_parts) + + +# Regex to parse HTTP Range header +RANGE_REGEX = re.compile(r"^\s*bytes\s*=\s*(\d*)\s*-\s*(\d*)\s*$", re.IGNORECASE) + + +def _adjust_range_header(original_range: Optional[str], resume_size: int) -> Optional[str]: + """ + Adjust HTTP Range header to account for resume position. + """ + if not original_range: + return f"bytes={resume_size}-" + + if "," in original_range: + raise ValueError(f"Multiple ranges detected - {original_range!r}, not supported yet.") + + match = RANGE_REGEX.match(original_range) + if not match: + raise RuntimeError(f"Invalid range format - {original_range!r}.") + start, end = match.groups() + + if not start: + if not end: + raise RuntimeError(f"Invalid range format - {original_range!r}.") + + new_suffix = int(end) - resume_size + new_range = f"bytes=-{new_suffix}" + if new_suffix <= 0: + raise RuntimeError(f"Empty new range - {new_range!r}.") + return new_range + + start = int(start) + new_start = start + resume_size + if end: + end = int(end) + new_range = f"bytes={new_start}-{end}" + if new_start > end: + raise RuntimeError(f"Empty new range - {new_range!r}.") + return new_range + + return f"bytes={new_start}-" diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 6b83261cf3..f20794e241 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -19,7 +19,7 @@ import warnings from contextlib import contextmanager from pathlib import Path -from typing import Iterable +from typing import Iterable, List from unittest.mock import Mock, patch import pytest @@ -932,8 +932,8 @@ def test_get_pointer_path_but_invalid_relative_filename(self) -> None: _get_pointer_path("path/to/storage", "abcdef", relative_filename) -class TestHttpGet(unittest.TestCase): - def test_http_get_with_ssl_and_timeout_error(self): +class TestHttpGet: + def test_http_get_with_ssl_and_timeout_error(self, caplog): def _iter_content_1() -> Iterable[bytes]: yield b"0" * 10 yield b"0" * 10 @@ -966,22 +966,99 @@ def _iter_content_4() -> Iterable[bytes]: temp_file = io.BytesIO() - with self.assertLogs("huggingface_hub.file_download", level="WARNING") as records: - http_get("fake_url", temp_file=temp_file) + http_get("fake_url", temp_file=temp_file) - # Check 3 warnings - self.assertEqual(len(records.records), 3) + assert len([r for r in caplog.records if r.levelname == "WARNING"]) == 3 # Check final value - self.assertEqual(temp_file.tell(), 100) - self.assertEqual(temp_file.getvalue(), b"0" * 100) + assert temp_file.tell() == 100 + assert temp_file.getvalue() == b"0" * 100 # Check number of calls + correct range headers - self.assertEqual(len(mock.call_args_list), 4) - self.assertEqual(mock.call_args_list[0].kwargs["headers"], {}) - self.assertEqual(mock.call_args_list[1].kwargs["headers"], {"Range": "bytes=20-"}) - self.assertEqual(mock.call_args_list[2].kwargs["headers"], {"Range": "bytes=30-"}) - self.assertEqual(mock.call_args_list[3].kwargs["headers"], {"Range": "bytes=60-"}) + assert len(mock.call_args_list) == 4 + assert mock.call_args_list[0].kwargs["headers"] == {} + assert mock.call_args_list[1].kwargs["headers"] == {"Range": "bytes=20-"} + assert mock.call_args_list[2].kwargs["headers"] == {"Range": "bytes=30-"} + assert mock.call_args_list[3].kwargs["headers"] == {"Range": "bytes=60-"} + + @pytest.mark.parametrize( + "initial_range,expected_ranges", + [ + # Test suffix ranges (bytes=-100) + ( + "bytes=-100", + [ + "bytes=-100", + "bytes=-80", + "bytes=-70", + "bytes=-40", + ], + ), + # Test prefix ranges (bytes=15-) + ( + "bytes=15-", + [ + "bytes=15-", + "bytes=35-", + "bytes=45-", + "bytes=75-", + ], + ), + # Test double closed ranges (bytes=15-114) + ( + "bytes=15-114", + [ + "bytes=15-114", + "bytes=35-114", + "bytes=45-114", + "bytes=75-114", + ], + ), + ], + ) + def test_http_get_with_range_headers(self, caplog, initial_range: str, expected_ranges: List[str]): + def _iter_content_1() -> Iterable[bytes]: + yield b"0" * 10 + yield b"0" * 10 + raise requests.exceptions.SSLError("Fake SSLError") + + def _iter_content_2() -> Iterable[bytes]: + yield b"0" * 10 + raise requests.ReadTimeout("Fake ReadTimeout") + + def _iter_content_3() -> Iterable[bytes]: + yield b"0" * 10 + yield b"0" * 10 + yield b"0" * 10 + raise requests.ConnectionError("Fake ConnectionError") + + def _iter_content_4() -> Iterable[bytes]: + yield b"0" * 10 + yield b"0" * 10 + yield b"0" * 10 + yield b"0" * 10 + + with patch("huggingface_hub.file_download._request_wrapper") as mock: + mock.return_value.headers = {"Content-Length": 100} + mock.return_value.iter_content.side_effect = [ + _iter_content_1(), + _iter_content_2(), + _iter_content_3(), + _iter_content_4(), + ] + + temp_file = io.BytesIO() + + http_get("fake_url", temp_file=temp_file, headers={"Range": initial_range}) + + assert len([r for r in caplog.records if r.levelname == "WARNING"]) == 3 + + assert temp_file.tell() == 100 + assert temp_file.getvalue() == b"0" * 100 + + assert len(mock.call_args_list) == 4 + for i, expected_range in enumerate(expected_ranges): + assert mock.call_args_list[i].kwargs["headers"] == {"Range": expected_range} class CreateSymlinkTest(unittest.TestCase): diff --git a/tests/test_utils_http.py b/tests/test_utils_http.py index 71262aa1e8..07037e6aba 100644 --- a/tests/test_utils_http.py +++ b/tests/test_utils_http.py @@ -14,6 +14,7 @@ from huggingface_hub.constants import ENDPOINT from huggingface_hub.utils._http import ( OfflineModeIsEnabled, + _adjust_range_header, configure_http_backend, fix_hf_endpoint_in_url, get_session, @@ -311,3 +312,27 @@ def _is_uuid(string: str) -> bool: ) def test_fix_hf_endpoint_in_url(base_url: str, endpoint: Optional[str], expected_url: str) -> None: assert fix_hf_endpoint_in_url(base_url, endpoint) == expected_url + + +def test_adjust_range_header(): + # Basic cases + assert _adjust_range_header(None, 10) == "bytes=10-" + assert _adjust_range_header("bytes=0-100", 10) == "bytes=10-100" + assert _adjust_range_header("bytes=-100", 10) == "bytes=-90" + assert _adjust_range_header("bytes=100-", 10) == "bytes=110-" + + with pytest.raises(RuntimeError): + _adjust_range_header("invalid", 10) + + with pytest.raises(RuntimeError): + _adjust_range_header("bytes=-", 10) + + # Multiple ranges + with pytest.raises(ValueError): + _adjust_range_header("bytes=0-100,200-300", 10) + + # Resume size exceeds range + with pytest.raises(RuntimeError): + _adjust_range_header("bytes=0-100", 150) + with pytest.raises(RuntimeError): + _adjust_range_header("bytes=-50", 100)