Skip to content

Commit 6984294

Browse files
authored
feature(store): add LoggingStore wrapper (zarr-developers#2231)
* feature(store): add LoggingStore wrapper
* add counter
* lint
1 parent 3365928 commit 6984294

File tree

2 files changed

+212
-0
lines changed

2 files changed

+212
-0
lines changed

src/zarr/store/logging.py

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
import logging
5+
import time
6+
from collections import defaultdict
7+
from contextlib import contextmanager
8+
from typing import TYPE_CHECKING
9+
10+
from zarr.abc.store import AccessMode, Store
11+
12+
if TYPE_CHECKING:
13+
from collections.abc import AsyncGenerator, Generator
14+
15+
from zarr.core.buffer import Buffer, BufferPrototype
16+
17+
18+
class LoggingStore(Store):
    """Store wrapper that logs and counts the calls made to a wrapped store.

    Every operation is delegated unchanged to the wrapped store. Around each
    delegation a ``Calling <StoreClass>.<method>`` record and a
    ``Finished <StoreClass>.<method> in <t> seconds`` record are emitted at
    INFO level, and ``counter[<method>]`` is incremented.

    Parameters
    ----------
    store : Store
        The store to wrap.
    log_level : str
        Level applied to the logger and to the default handler.
    log_handler : logging.Handler, optional
        Handler to attach to the logger. When omitted, a stream handler is
        installed only if no handler is already reachable from the logger.

    Attributes
    ----------
    counter : defaultdict[str, int]
        Number of times each logged method/property has been entered.
    logger : logging.Logger
        Logger named ``LoggingStore(<str(store)>)``.
    log_level : str
        The level passed at construction time.
    """

    _store: Store
    counter: defaultdict[str, int]

    def __init__(
        self,
        store: Store,
        log_level: str = "DEBUG",
        log_handler: logging.Handler | None = None,
    ):
        # NOTE(review): the base-class initialiser is intentionally not called
        # here (the original did not call it either). ``_mode`` and
        # ``_is_open`` are read-only properties on this class, so a base
        # ``__init__`` that assigns them would presumably fail -- confirm
        # before adding ``super().__init__()``.
        self._store = store
        self.counter = defaultdict(int)

        self._configure_logger(log_level, log_handler)

    def _configure_logger(
        self, log_level: str = "DEBUG", log_handler: logging.Handler | None = None
    ) -> None:
        """Create the per-store logger and attach a handler to it.

        Parameters
        ----------
        log_level : str
            Level set on the logger (and used by the default handler).
        log_handler : logging.Handler, optional
            Handler to attach. An explicitly supplied handler is always
            attached; ``Logger.addHandler`` ignores a handler object that is
            already attached, so passing the same handler twice is harmless.
        """
        self.log_level = log_level
        self.logger = logging.getLogger(f"LoggingStore({self._store!s})")
        self.logger.setLevel(log_level)

        if log_handler is not None:
            # Honour the caller's handler even when the logger (or an
            # ancestor such as the root logger) already has handlers;
            # previously it was silently dropped in that situation.
            self.logger.addHandler(log_handler)
        elif not self.logger.hasHandlers():
            # Only fall back to the default handler when nothing else would
            # receive the records, to avoid duplicated output.
            self.logger.addHandler(self._default_handler())

    def _default_handler(self) -> logging.Handler:
        """Return a stream handler using ``self.log_level`` and a timestamped format."""
        handler = logging.StreamHandler()
        handler.setLevel(self.log_level)
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        return handler

    @contextmanager
    def log(self) -> Generator[None, None, None]:
        """Log entry to and exit from the calling method, and count the call.

        Must be used directly as ``with self.log():`` inside the method being
        logged: the method name is taken from the frame two levels up the
        call stack.

        Yields
        ------
        None
            Control returns to the ``with`` body; the "Finished" record is
            emitted when the body exits, whether normally or by exception.
        """
        # Stack at this point: [0] this generator, [1] the context manager's
        # ``__enter__``, [2] the method containing ``with self.log():``.
        # Walking two frames back is equivalent to ``inspect.stack()[2]`` but
        # avoids building frame info for the whole stack on every operation.
        frame = inspect.currentframe()
        caller = None
        if frame is not None and frame.f_back is not None:
            caller = frame.f_back.f_back
        if caller is not None:
            method = caller.f_code.co_name
        else:
            # ``inspect.currentframe`` may return None on interpreters
            # without Python stack frame support.
            method = inspect.stack()[2].function
        # Do not keep frame objects alive while the wrapped operation runs.
        del frame, caller

        op = f"{type(self._store).__name__}.{method}"
        self.logger.info("Calling %s", op)
        # perf_counter is monotonic, so the duration cannot go negative when
        # the system clock is adjusted.
        start_time = time.perf_counter()
        try:
            self.counter[method] += 1
            yield
        finally:
            elapsed = time.perf_counter() - start_time
            self.logger.info("Finished %s in %.2f seconds", op, elapsed)

    @property
    def supports_writes(self) -> bool:
        """Forward ``supports_writes`` from the wrapped store (logged)."""
        with self.log():
            return self._store.supports_writes

    @property
    def supports_deletes(self) -> bool:
        """Forward ``supports_deletes`` from the wrapped store (logged)."""
        with self.log():
            return self._store.supports_deletes

    @property
    def supports_partial_writes(self) -> bool:
        """Forward ``supports_partial_writes`` from the wrapped store (logged)."""
        with self.log():
            return self._store.supports_partial_writes

    @property
    def supports_listing(self) -> bool:
        """Forward ``supports_listing`` from the wrapped store (logged)."""
        with self.log():
            return self._store.supports_listing

    @property
    def _mode(self) -> AccessMode:  # type: ignore[override]
        """Forward ``_mode`` from the wrapped store (logged)."""
        with self.log():
            return self._store._mode

    @property
    def _is_open(self) -> bool:  # type: ignore[override]
        """Forward ``_is_open`` from the wrapped store (logged)."""
        with self.log():
            return self._store._is_open

    async def empty(self) -> bool:
        """Delegate ``empty`` to the wrapped store (logged)."""
        with self.log():
            return await self._store.empty()

    async def clear(self) -> None:
        """Delegate ``clear`` to the wrapped store (logged)."""
        with self.log():
            return await self._store.clear()

    def __str__(self) -> str:
        return f"logging-{self._store!s}"

    def __repr__(self) -> str:
        # A single repr of the wrapped store; applying repr twice would wrap
        # the inner repr in quotes and escape its contents.
        return f"LoggingStore({self._store!r})"

    def __eq__(self, other: object) -> bool:
        """Compare the wrapped store with ``other`` (logged).

        If ``other`` is itself a ``LoggingStore`` its wrapped store is used,
        so two wrappers around equal stores compare equal. Any other object
        is compared against the wrapped store directly, as before.
        """
        with self.log():
            if isinstance(other, LoggingStore):
                other = other._store
            return self._store == other

    async def get(
        self,
        key: str,
        prototype: BufferPrototype,
        byte_range: tuple[int | None, int | None] | None = None,
    ) -> Buffer | None:
        """Delegate ``get`` to the wrapped store (logged)."""
        with self.log():
            return await self._store.get(key=key, prototype=prototype, byte_range=byte_range)

    async def get_partial_values(
        self,
        prototype: BufferPrototype,
        key_ranges: list[tuple[str, tuple[int | None, int | None]]],
    ) -> list[Buffer | None]:
        """Delegate ``get_partial_values`` to the wrapped store (logged)."""
        with self.log():
            return await self._store.get_partial_values(prototype=prototype, key_ranges=key_ranges)

    async def exists(self, key: str) -> bool:
        """Delegate ``exists`` to the wrapped store (logged)."""
        with self.log():
            return await self._store.exists(key)

    async def set(self, key: str, value: Buffer) -> None:
        """Delegate ``set`` to the wrapped store (logged)."""
        with self.log():
            return await self._store.set(key=key, value=value)

    async def delete(self, key: str) -> None:
        """Delegate ``delete`` to the wrapped store (logged)."""
        with self.log():
            return await self._store.delete(key=key)

    async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None:
        """Delegate ``set_partial_values`` to the wrapped store (logged)."""
        with self.log():
            return await self._store.set_partial_values(key_start_values=key_start_values)

    async def list(self) -> AsyncGenerator[str, None]:
        """Yield every key from the wrapped store's ``list`` (logged).

        The "Finished" record is emitted once iteration ends, so the reported
        duration covers the whole iteration.
        """
        with self.log():
            async for key in self._store.list():
                yield key

    async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
        """Yield every key from the wrapped store's ``list_prefix`` (logged)."""
        with self.log():
            async for key in self._store.list_prefix(prefix=prefix):
                yield key

    async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
        """Yield every key from the wrapped store's ``list_dir`` (logged)."""
        with self.log():
            async for key in self._store.list_dir(prefix=prefix):
                yield key

tests/v3/test_store/test_logging.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import pytest
6+
7+
import zarr
8+
from zarr.core.buffer import default_buffer_prototype
9+
from zarr.store.logging import LoggingStore
10+
11+
if TYPE_CHECKING:
12+
from zarr.abc.store import Store
13+
14+
15+
@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
async def test_logging_store(store: Store, caplog) -> None:
    """Check that ``set`` and ``list`` each emit a Calling/Finished log pair.

    For both operations exactly two records are expected, each coming from a
    logger whose name contains ``str(store)``.
    """
    logged = LoggingStore(store=store, log_level="DEBUG")
    buffer_cls = default_buffer_prototype().buffer

    # --- set -------------------------------------------------------------
    caplog.clear()
    payload = buffer_cls.from_bytes(b"\x01\x02\x03\x04")
    outcome = await logged.set("foo/bar/c/0", payload)
    assert outcome is None

    set_records = caplog.record_tuples
    assert len(set_records) == 2
    for logger_name, _level, _message in set_records:
        assert str(store) in logger_name
    first_message, second_message = set_records[0][2], set_records[1][2]
    assert f"Calling {type(store).__name__}.set" in first_message
    assert f"Finished {type(store).__name__}.set" in second_message

    # --- list ------------------------------------------------------------
    caplog.clear()
    listed = []
    async for name in logged.list():
        listed.append(name)
    assert listed == ["foo/bar/c/0"]

    list_records = caplog.record_tuples
    assert len(list_records) == 2
    for logger_name, _level, _message in list_records:
        assert str(store) in logger_name
    first_message, second_message = list_records[0][2], list_records[1][2]
    assert f"Calling {type(store).__name__}.list" in first_message
    assert f"Finished {type(store).__name__}.list" in second_message
37+
38+
39+
@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
async def test_logging_store_counter(store: Store) -> None:
    """Check the per-method call counts after creating and filling an array."""
    logged = LoggingStore(store=store, log_level="DEBUG")

    array = zarr.create(shape=(10,), store=logged, overwrite=True)
    array[:] = 1

    # Expected number of calls per store method, checked in this order.
    expected_calls = {
        "set": 2,
        "get": 0,  # 1 if overwrite=False
        "list": 0,
        "list_dir": 0,
        "list_prefix": 0,
    }
    for method_name, expected in expected_calls.items():
        assert logged.counter[method_name] == expected

0 commit comments

Comments
 (0)