Skip to content

Commit

Permalink
update nested refresh
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Nov 8, 2024
1 parent c9fff3b commit ff7d89b
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 9 deletions.
70 changes: 70 additions & 0 deletions tests/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,73 @@ def test_dct_merge(client, request):
new_dct = dct | dct2
assert new_dct == {"a": "1", "b": "3", "c": "4"}
assert isinstance(new_dct, dict)


@pytest.mark.parametrize("client", ["znsclient", "znsclient_w_redis", "redisclient"])
def test_dict_nested_refresh(client, request, znsclient):
r = request.getfixturevalue(client)
dct = znsocket.Dict(r=r, key="dct:test", socket=znsclient)
dctA = znsocket.Dict(r=r, key="dct:test:A", socket=znsclient)
dctB = znsocket.Dict(r=r, key="dct:test:B", socket=znsclient)

dct2 = znsocket.Dict(
r=r, key="dct:test", socket=znsocket.Client.from_url(znsclient.address)
)
mock = MagicMock()
dct2.on_refresh(mock, nested=True)

dct["A"] = dctA
dct["B"] = dctB

znsclient.sio.sleep(0.2)

assert mock.call_count == 2

mock.reset_mock()

dctA["key"] = "value"

znsclient.sio.sleep(0.2)

assert mock.call_count == 1
assert mock.call_args[0][0] == {"keys": ["A"]}

mock.reset_mock()

dctB["lorem"] = "ipsum"

znsclient.sio.sleep(0.2)

assert mock.call_count == 1
assert mock.call_args[0][0] == {"keys": ["B"]}

mock.reset_mock()

assert dct["A"]["key"] == "value"
assert dct["B"]["lorem"] == "ipsum"

# now for good measure add a List and refresh that

lst = znsocket.List(r=r, key="lst:test", socket=znsclient)
lst.append("item")
dct["L"] = lst

znsclient.sio.sleep(0.2)

mock.reset_mock()

lst.append("item2")

znsclient.sio.sleep(0.2)
assert mock.call_count == 1
assert mock.call_args[0][0] == {"keys": ["L"]}

mock.reset_mock()
dct2.nested_refresh = False

dctA["key"] = "value2"
dctB["lorem"] = "ipsum2"

znsclient.sio.sleep(0.5)

assert mock.call_count == 0
71 changes: 71 additions & 0 deletions tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,74 @@ def test_list_refresh_extend_self_trigger(client, request, znsclient):
znsclient.sio.sleep(0.01)
assert len(lst) == 3
mock.assert_not_called()




@pytest.mark.parametrize("client", ["znsclient", "znsclient_w_redis", "redisclient"])
def test_list_nested_refresh(client, request, znsclient):
r = request.getfixturevalue(client)
lst = znsocket.List(r=r, key="lst:test", socket=znsclient)
dctA = znsocket.Dict(r=r, key="dct:test:A", socket=znsclient)
dctB = znsocket.Dict(r=r, key="dct:test:B", socket=znsclient)

lst2 = znsocket.List(
r=r, key="lst:test", socket=znsocket.Client.from_url(znsclient.address)
)
mock = MagicMock()
lst2.on_refresh(mock, nested=True)

lst.extend([dctA, dctB])

znsclient.sio.sleep(0.5)

assert mock.call_count == 1

mock.reset_mock()

dctA["key"] = "value"

znsclient.sio.sleep(0.2)

assert mock.call_count == 1
assert mock.call_args[0][0] == {"indices": [0]}

mock.reset_mock()

dctB["lorem"] = "ipsum"

znsclient.sio.sleep(0.2)

assert mock.call_count == 1
assert mock.call_args[0][0] =={"indices": [1]}

mock.reset_mock()

assert lst[0]["key"] == "value"
assert lst[1]["lorem"] == "ipsum"

# now for good measure add a list

lst3 = znsocket.List(r=r, key="lst:test:3", socket=znsclient)
lst.append(lst3)

znsclient.sio.sleep(0.2)
mock.reset_mock()

lst3.append("value")

znsclient.sio.sleep(0.2)

assert mock.call_count == 1
assert mock.call_args[0][0] == {"indices": [2]}

mock.reset_mock()

lst2.nested_refresh = False

dctA["key"] = "value2"
dctB["lorem"] = "ipsum2"

znsclient.sio.sleep(0.2)

assert mock.call_count == 0
10 changes: 5 additions & 5 deletions znsocket/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Client:
default_factory=socketio.Client, repr=False, init=False
)
namespace: str = "/znsocket"
refresh_callbacks: dict = dataclasses.field(default_factory=dict)
attachments: list = dataclasses.field(default_factory=list)

def pipeline(self, *args, **kwargs) -> "Pipeline":
return Pipeline(self, *args, **kwargs)
Expand All @@ -77,11 +77,11 @@ def from_url(cls, url, namespace: str = "/znsocket", **kwargs) -> "Client":
)

def __post_init__(self):

@self.sio.on("refresh", namespace=self.namespace)
def refresh(data: RefreshDataTypeDict):
for key in self.refresh_callbacks:
if data["target"] == key:
self.refresh_callbacks[key](data["data"])
def _(data: RefreshDataTypeDict):
for obj in self.attachments:
obj.refresh(target=data["target"], data=data["data"])

_url, _path = parse_url(self.address)
try:
Expand Down
65 changes: 61 additions & 4 deletions znsocket/objects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def __init__(
self.converter = converter
self._on_refresh = lambda x: None

self.nested_refresh = False
self.refresh_callback = None

if isinstance(r, Client):
self._pipeline_kwargs = {"max_commands_per_call": max_commands_per_call}
else:
Expand Down Expand Up @@ -309,11 +312,37 @@ def copy(self, key: str) -> "List":

return List(r=self.redis, key=key, socket=self.socket)

def on_refresh(self, callback: t.Callable[[RefreshDataTypeDict], None]) -> None:
def on_refresh(self, callback: t.Callable[[RefreshDataTypeDict], None], nested: bool = False) -> None:
if self.socket is None:
raise ValueError("No socket connection available")
self.nested_refresh = nested
self.socket.attachments.append(self)
self.refresh_callback = callback


def refresh(self, target: str, data: RefreshTypeDict) -> None:
if self.refresh_callback is None:
return

updated_keys = []

self.socket.refresh_callbacks[self.key] = callback
# If target does not match this object's key, look in nested dictionaries and lists.
if target != self.key:
if not self.nested_refresh:
return

for idx, value in enumerate(self):
if isinstance(value, Dict) and value.key == target:
updated_keys.append(idx)
elif isinstance(value, List) and value.key == target:
updated_keys.append(idx)

if updated_keys:
self.refresh_callback({"indices": updated_keys})

# If the target matches this object's key, call the callback with the provided data.
else:
self.refresh_callback(data)


class Dict(MutableMapping, ZnSocketObject):
Expand Down Expand Up @@ -360,6 +389,8 @@ def __init__(
"setitem": None,
"delitem": None,
}
self.refresh_callback = None
self.nested_refresh = False
if callbacks:
self._callbacks.update(callbacks)

Expand Down Expand Up @@ -473,11 +504,37 @@ def copy(self, key: str) -> "Dict":

return Dict(r=self.redis, key=key, socket=self.socket)

def on_refresh(self, callback: t.Callable[[RefreshDataTypeDict], None]) -> None:
def on_refresh(self, callback: t.Callable[[RefreshDataTypeDict], None], nested: bool = False) -> None:
if self.socket is None:
raise ValueError("No socket connection available")
self.nested_refresh = nested
self.socket.attachments.append(self)
self.refresh_callback = callback

self.socket.refresh_callbacks[self.key] = callback

def refresh(self, target: str, data: RefreshTypeDict) -> None:
if self.refresh_callback is None:
return

updated_keys = []

# If target does not match this object's key, look in nested dictionaries and lists.
if target != self.key:
if not self.nested_refresh:
return

for key, value in self.items():
if isinstance(value, Dict) and value.key == target:
updated_keys.append(key)
elif isinstance(value, List) and value.key == target:
updated_keys.append(key)

if updated_keys:
self.refresh_callback({"keys": updated_keys})

# If the target matches this object's key, call the callback with the provided data.
else:
self.refresh_callback(data)

def update(self, *args, **kwargs):
"""Update the dict with another dict or iterable."""
Expand Down

0 comments on commit ff7d89b

Please sign in to comment.