Skip to content

Commit 76eed47

Browse files
authored
add wrap_file for wrapping a file object with callback (#271)
add wrap_file for wrapping a file object with contextmanager
1 parent cd89d9f commit 76eed47

File tree

4 files changed

+33
-25
lines changed

4 files changed

+33
-25
lines changed

src/dvc_objects/fs/base.py

+2 -3
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
Sequence,
2222
Tuple,
2323
Union,
24-
cast,
2524
overload,
2625
)
2726
from urllib.parse import urlsplit, urlunsplit
@@ -34,8 +33,8 @@
3433

3534
from .callbacks import (
3635
DEFAULT_CALLBACK,
37-
CallbackStream,
3836
wrap_and_branch_callback,
37+
wrap_file,
3938
)
4039
from .errors import RemoteMissingDepsError
4140

@@ -637,7 +636,7 @@ def put_file(
637636
if size:
638637
callback.set_size(size)
639638
if hasattr(from_file, "read"):
640-
stream = cast("BinaryIO", CallbackStream(from_file, callback))
639+
stream = wrap_file(from_file, callback)
641640
self.upload_fobj(stream, to_info, size=size)
642641
else:
643642
assert isinstance(from_file, str)

src/dvc_objects/fs/callbacks.py

+13 -19
Original file line number | Diff line number | Diff line change
@@ -1,38 +1,28 @@
11
import asyncio
22
from functools import wraps
3-
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, TypeVar
3+
from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Optional, TypeVar, cast
44

55
import fsspec
66

77
if TYPE_CHECKING:
8-
from typing import BinaryIO, Union
8+
from typing import Union
99

1010
from dvc_objects._tqdm import Tqdm
1111

1212
F = TypeVar("F", bound=Callable)
1313

1414

1515
class CallbackStream:
16-
def __init__(self, stream, callback, method="read"):
16+
def __init__(self, stream, callback: fsspec.Callback):
1717
self.stream = stream
18-
if method == "write":
1918

20-
@wraps(stream.write)
21-
def write(data, *args, **kwargs):
22-
res = stream.write(data, *args, **kwargs)
23-
callback.relative_update(len(data))
24-
return res
19+
@wraps(stream.read)
20+
def read(*args, **kwargs):
21+
data = stream.read(*args, **kwargs)
22+
callback.relative_update(len(data))
23+
return data
2524

26-
self.write = write
27-
else:
28-
29-
@wraps(stream.read)
30-
def read(*args, **kwargs):
31-
data = stream.read(*args, **kwargs)
32-
callback.relative_update(len(data))
33-
return data
34-
35-
self.read = read
25+
self.read = read
3626

3727
def __getattr__(self, attr):
3828
return getattr(self.stream, attr)
@@ -181,4 +171,8 @@ def wrap_and_branch_callback(callback: fsspec.Callback, fn: F) -> F:
181171
return wrap_fn(callback, branch_wrapper)
182172

183173

174+
def wrap_file(file, callback: fsspec.Callback) -> BinaryIO:
175+
return cast(BinaryIO, CallbackStream(file, callback))
176+
177+
184178
DEFAULT_CALLBACK = NoOpCallback()

src/dvc_objects/fs/utils.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from dvc_objects.executors import ThreadPoolExecutor
1414

1515
from . import system
16-
from .callbacks import DEFAULT_CALLBACK, CallbackStream
16+
from .callbacks import DEFAULT_CALLBACK, wrap_file
1717

1818
if TYPE_CHECKING:
1919
from .base import AnyFSPath, FileSystem
@@ -168,8 +168,8 @@ def copyfile(
168168

169169
callback.set_size(total)
170170
with open(src, "rb") as fsrc, open(dest, "wb+") as fdest:
171-
wrapped = CallbackStream(fdest, callback, "write")
172-
shutil.copyfileobj(fsrc, wrapped, length=LOCAL_CHUNK_SIZE)
171+
wrapped = wrap_file(fsrc, callback)
172+
shutil.copyfileobj(wrapped, fdest, length=LOCAL_CHUNK_SIZE)
173173

174174

175175
def tmp_fname(prefix: str = "") -> str:

tests/fs/test_callbacks.py

+15
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
TqdmCallback,
1010
branch_callback,
1111
wrap_and_branch_callback,
12+
wrap_file,
1213
wrap_fn,
1314
)
1415

@@ -146,3 +147,17 @@ async def test_wrap_and_branch_callback_async(mocker, cb_class):
146147
m.assert_any_call("argA", "argB", arg3="argC", callback=IsDVCCallback())
147148
assert callback.value == 2
148149
assert spy.call_count == 2
150+
151+
152+
def test_wrap_file(memfs):
153+
memfs.pipe_file("/file", b"foo\n")
154+
155+
callback = Callback()
156+
157+
callback.set_size(4)
158+
with memfs.open("/file", mode="rb") as f:
159+
wrapped = wrap_file(f, callback)
160+
assert wrapped.read() == b"foo\n"
161+
162+
assert callback.value == 4
163+
assert callback.size == 4

0 commit comments

Comments
 (0)