Skip to content

Commit 92d11b0

Browse files
committed
Moved files to new branch to avoid weird git bug
Signed-off-by: Lorenzo Curcio <[email protected]>
1 parent 6e90e84 commit 92d11b0

File tree

6 files changed

+693
-3
lines changed

6 files changed

+693
-3
lines changed

dapr/actor/runtime/mock_actor.py

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
Copyright 2023 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
from datetime import timedelta
15+
from typing import Any, Optional, TypeVar
16+
17+
from dapr.actor.id import ActorId
18+
from dapr.actor.runtime._reminder_data import ActorReminderData
19+
from dapr.actor.runtime._timer_data import TIMER_CALLBACK, ActorTimerData
20+
from dapr.actor.runtime.actor import Actor
21+
from dapr.actor.runtime.mock_state_manager import MockStateManager
22+
from dapr.actor.runtime.state_manager import ActorStateManager
23+
24+
25+
class MockActor(Actor):
26+
"""A mock actor class to be used to override certain Actor methods for unit testing.
27+
To be used only via the create_mock_actor function, which takes in a class and returns a
28+
mock actor object for that class.
29+
30+
Examples:
31+
class SomeActorInterface(ActorInterface):
32+
@actor_method(name="method")
33+
async def set_state(self, data: dict) -> None:
34+
35+
class SomeActor(Actor, SomeActorInterface):
36+
async def set_state(self, data: dict) -> None:
37+
await self._state_manager.set_state('state', data)
38+
await self._state_manager.save_state()
39+
40+
mock_actor = create_mock_actor(SomeActor)
41+
assert mock_actor._state_manager._mock_state == {}
42+
await mock_actor.set_state({"test":10})
43+
assert mock_actor._state_manager._mock_state == {"test":10}
44+
"""
45+
46+
def __init__(self, actor_id: str, initstate: Optional[dict]):
47+
self.id = ActorId(actor_id)
48+
self._runtime_ctx = None
49+
self._state_manager: ActorStateManager = MockStateManager(self, initstate)
50+
51+
async def register_timer(
52+
self,
53+
name: Optional[str],
54+
callback: TIMER_CALLBACK,
55+
state: Any,
56+
due_time: timedelta,
57+
period: timedelta,
58+
ttl: Optional[timedelta] = None,
59+
) -> None:
60+
"""Adds actor timer to self._state_manager._mock_timers.
61+
Args:
62+
name (str): the name of the timer to register.
63+
callback (Callable): An awaitable callable which will be called when the timer fires.
64+
state (Any): An object which will pass to the callback method, or None.
65+
due_time (datetime.timedelta): the amount of time to delay before the awaitable
66+
callback is first invoked.
67+
period (datetime.timedelta): the time interval between invocations
68+
of the awaitable callback.
69+
ttl (Optional[datetime.timedelta]): the time interval before the timer stops firing
70+
"""
71+
name = name or self.__get_new_timer_name()
72+
timer = ActorTimerData(name, callback, state, due_time, period, ttl)
73+
self._state_manager._mock_timers[name] = timer
74+
75+
async def unregister_timer(self, name: str) -> None:
76+
"""Unregisters actor timer from self._state_manager._mock_timers.
77+
78+
Args:
79+
name (str): the name of the timer to unregister.
80+
"""
81+
self._state_manager._mock_timers.pop(name, None)
82+
83+
async def register_reminder(
84+
self,
85+
name: str,
86+
state: bytes,
87+
due_time: timedelta,
88+
period: timedelta,
89+
ttl: Optional[timedelta] = None,
90+
) -> None:
91+
"""Adds actor reminder to self._state_manager._mock_reminders.
92+
93+
Args:
94+
name (str): the name of the reminder to register. the name must be unique per actor.
95+
state (bytes): the user state passed to the reminder invocation.
96+
due_time (datetime.timedelta): the amount of time to delay before invoking the reminder
97+
for the first time.
98+
period (datetime.timedelta): the time interval between reminder invocations after
99+
the first invocation.
100+
ttl (datetime.timedelta): the time interval before the reminder stops firing
101+
"""
102+
reminder = ActorReminderData(name, state, due_time, period, ttl)
103+
self._state_manager._mock_reminders[name] = reminder
104+
105+
async def unregister_reminder(self, name: str) -> None:
106+
"""Unregisters actor reminder from self._state_manager._mock_reminders..
107+
108+
Args:
109+
name (str): the name of the reminder to unregister.
110+
"""
111+
self._state_manager._mock_reminders.pop(name, None)
112+
113+
114+
T = TypeVar('T', bound=Actor)
115+
116+
117+
def create_mock_actor(cls1: type[T], actor_id: str, initstate: Optional[dict] = None) -> T:
118+
class MockSuperClass(MockActor, cls1):
119+
pass
120+
121+
return MockSuperClass(actor_id, initstate) # type: ignore
+235
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
"""
2+
Copyright 2023 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
import asyncio
15+
from contextvars import ContextVar
16+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar
17+
18+
from dapr.actor.runtime._reminder_data import ActorReminderData
19+
from dapr.actor.runtime._timer_data import ActorTimerData
20+
from dapr.actor.runtime.state_change import ActorStateChange, StateChangeKind
21+
from dapr.actor.runtime.state_manager import ActorStateManager, StateMetadata
22+
23+
if TYPE_CHECKING:
24+
from dapr.actor.runtime.mock_actor import MockActor
25+
26+
T = TypeVar('T')
27+
CONTEXT: ContextVar[Optional[Dict[str, Any]]] = ContextVar('state_tracker_context')
28+
29+
30+
class MockStateManager(ActorStateManager):
31+
def __init__(self, actor: 'MockActor', initstate: Optional[dict]):
32+
self._actor = actor
33+
self._default_state_change_tracker: Dict[str, StateMetadata] = {}
34+
self._mock_state: dict[str, Any] = {}
35+
self._mock_timers: dict[str, ActorTimerData] = {}
36+
self._mock_reminders: dict[str, ActorReminderData] = {}
37+
if initstate:
38+
self._mock_state = initstate
39+
40+
async def add_state(self, state_name: str, value: T) -> None:
41+
if not await self.try_add_state(state_name, value):
42+
raise ValueError(f'The actor state name {state_name} already exist.')
43+
44+
async def try_add_state(self, state_name: str, value: T) -> bool:
45+
if state_name in self._default_state_change_tracker:
46+
state_metadata = self._default_state_change_tracker[state_name]
47+
if state_metadata.change_kind == StateChangeKind.remove:
48+
self._default_state_change_tracker[state_name] = StateMetadata(
49+
value, StateChangeKind.update
50+
)
51+
return True
52+
return False
53+
existed = state_name in self._mock_state
54+
if not existed:
55+
return False
56+
self._default_state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.add)
57+
self._mock_state[state_name] = value
58+
return True
59+
60+
async def get_state(self, state_name: str) -> Optional[T]:
61+
has_value, val = await self.try_get_state(state_name)
62+
if has_value:
63+
return val
64+
else:
65+
raise KeyError(f'Actor State with name {state_name} was not found.')
66+
67+
async def try_get_state(self, state_name: str) -> Tuple[bool, Optional[T]]:
68+
if state_name in self._default_state_change_tracker:
69+
state_metadata = self._default_state_change_tracker[state_name]
70+
if state_metadata.change_kind == StateChangeKind.remove:
71+
return False, None
72+
return True, state_metadata.value
73+
has_value = state_name in self._mock_state
74+
val = self._mock_state.get(state_name)
75+
if has_value:
76+
self._default_state_change_tracker[state_name] = StateMetadata(
77+
val, StateChangeKind.none
78+
)
79+
return has_value, val
80+
81+
async def set_state(self, state_name: str, value: T) -> None:
82+
await self.set_state_ttl(state_name, value, None)
83+
84+
async def set_state_ttl(self, state_name: str, value: T, ttl_in_seconds: Optional[int]) -> None:
85+
if ttl_in_seconds is not None and ttl_in_seconds < 0:
86+
return
87+
88+
if state_name in self._default_state_change_tracker:
89+
state_metadata = self._default_state_change_tracker[state_name]
90+
state_metadata.value = value
91+
state_metadata.ttl_in_seconds = ttl_in_seconds
92+
93+
if (
94+
state_metadata.change_kind == StateChangeKind.none
95+
or state_metadata.change_kind == StateChangeKind.remove
96+
):
97+
state_metadata.change_kind = StateChangeKind.update
98+
self._default_state_change_tracker[state_name] = state_metadata
99+
self._mock_state[state_name] = value
100+
return
101+
102+
existed = state_name in self._mock_state
103+
if existed:
104+
self._default_state_change_tracker[state_name] = StateMetadata(
105+
value, StateChangeKind.update, ttl_in_seconds
106+
)
107+
else:
108+
self._default_state_change_tracker[state_name] = StateMetadata(
109+
value, StateChangeKind.add, ttl_in_seconds
110+
)
111+
self._mock_state[state_name] = value
112+
113+
async def remove_state(self, state_name: str) -> None:
114+
if not await self.try_remove_state(state_name):
115+
raise KeyError(f'Actor State with name {state_name} was not found.')
116+
117+
async def try_remove_state(self, state_name: str) -> bool:
118+
if state_name in self._default_state_change_tracker:
119+
state_metadata = self._default_state_change_tracker[state_name]
120+
if state_metadata.change_kind == StateChangeKind.remove:
121+
return False
122+
elif state_metadata.change_kind == StateChangeKind.add:
123+
self._default_state_change_tracker.pop(state_name, None)
124+
self._mock_state.pop(state_name, None)
125+
return True
126+
self._mock_state.pop(state_name, None)
127+
state_metadata.change_kind = StateChangeKind.remove
128+
return True
129+
130+
existed = state_name in self._mock_state
131+
if existed:
132+
self._default_state_change_tracker[state_name] = StateMetadata(
133+
None, StateChangeKind.remove
134+
)
135+
self._mock_state.pop(state_name, None)
136+
return True
137+
return False
138+
139+
async def contains_state(self, state_name: str) -> bool:
140+
if state_name in self._default_state_change_tracker:
141+
state_metadata = self._default_state_change_tracker[state_name]
142+
return state_metadata.change_kind != StateChangeKind.remove
143+
return state_name in self._mock_state
144+
145+
async def get_or_add_state(self, state_name: str, value: T) -> Optional[T]:
146+
has_value, val = await self.try_get_state(state_name)
147+
if has_value:
148+
return val
149+
change_kind = (
150+
StateChangeKind.update
151+
if self.is_state_marked_for_remove(state_name)
152+
else StateChangeKind.add
153+
)
154+
self._default_state_change_tracker[state_name] = StateMetadata(value, change_kind)
155+
return value
156+
157+
async def add_or_update_state(
158+
self, state_name: str, value: T, update_value_factory: Callable[[str, T], T]
159+
) -> T:
160+
if not callable(update_value_factory):
161+
raise AttributeError('update_value_factory is not callable')
162+
163+
if state_name in self._default_state_change_tracker:
164+
state_metadata = self._default_state_change_tracker[state_name]
165+
if state_metadata.change_kind == StateChangeKind.remove:
166+
self._default_state_change_tracker[state_name] = StateMetadata(
167+
value, StateChangeKind.update
168+
)
169+
self._mock_state[state_name] = value
170+
return value
171+
new_value = update_value_factory(state_name, state_metadata.value)
172+
state_metadata.value = new_value
173+
if state_metadata.change_kind == StateChangeKind.none:
174+
state_metadata.change_kind = StateChangeKind.update
175+
self._default_state_change_tracker[state_name] = state_metadata
176+
self._mock_state[state_name] = value
177+
return new_value
178+
179+
has_value = state_name in self._mock_state
180+
val: Any = self._mock_state.get(state_name)
181+
if has_value:
182+
new_value = update_value_factory(state_name, val)
183+
self._default_state_change_tracker[state_name] = StateMetadata(
184+
new_value, StateChangeKind.update
185+
)
186+
return new_value
187+
self._default_state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.add)
188+
return value
189+
190+
async def get_state_names(self) -> List[str]:
191+
# TODO: Get all state names from Dapr once implemented.
192+
def append_names_sync():
193+
state_names = []
194+
for key, value in self._default_state_change_tracker.items():
195+
if value.change_kind == StateChangeKind.add:
196+
state_names.append(key)
197+
elif value.change_kind == StateChangeKind.remove:
198+
state_names.append(key)
199+
return state_names
200+
201+
default_loop = asyncio.get_running_loop()
202+
return await default_loop.run_in_executor(None, append_names_sync)
203+
204+
async def clear_cache(self) -> None:
205+
self._default_state_change_tracker.clear()
206+
207+
async def save_state(self) -> None:
208+
if len(self._default_state_change_tracker) == 0:
209+
return
210+
211+
state_changes = []
212+
states_to_remove = []
213+
for state_name, state_metadata in self._default_state_change_tracker.items():
214+
if state_metadata.change_kind == StateChangeKind.none:
215+
continue
216+
state_changes.append(
217+
ActorStateChange(
218+
state_name,
219+
state_metadata.value,
220+
state_metadata.change_kind,
221+
state_metadata.ttl_in_seconds,
222+
)
223+
)
224+
if state_metadata.change_kind == StateChangeKind.remove:
225+
states_to_remove.append(state_name)
226+
# Mark the states as unmodified so that tracking for next invocation is done correctly.
227+
state_metadata.change_kind = StateChangeKind.none
228+
for state_name in states_to_remove:
229+
self._default_state_change_tracker.pop(state_name, None)
230+
231+
def is_state_marked_for_remove(self, state_name: str) -> bool:
232+
return (
233+
state_name in self._default_state_change_tracker
234+
and self._default_state_change_tracker[state_name].change_kind == StateChangeKind.remove
235+
)

dapr/actor/runtime/state_manager.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515

1616
import asyncio
1717
from contextvars import ContextVar
18+
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar
1819

19-
from dapr.actor.runtime.state_change import StateChangeKind, ActorStateChange
2020
from dapr.actor.runtime.reentrancy_context import reentrancy_ctx
21-
22-
from typing import Any, Callable, Dict, Generic, List, Tuple, TypeVar, Optional, TYPE_CHECKING
21+
from dapr.actor.runtime.state_change import ActorStateChange, StateChangeKind
2322

2423
if TYPE_CHECKING:
24+
from dapr.actor.runtime._reminder_data import ActorReminderData
25+
from dapr.actor.runtime._timer_data import ActorTimerData
2526
from dapr.actor.runtime.actor import Actor
2627

2728
T = TypeVar('T')
@@ -69,6 +70,9 @@ def __init__(self, actor: 'Actor'):
6970
self._type_name = actor.runtime_ctx.actor_type_info.type_name
7071

7172
self._default_state_change_tracker: Dict[str, StateMetadata] = {}
73+
self._mock_state: dict[str, Any]
74+
self._mock_timers: dict[str, ActorTimerData]
75+
self._mock_reminders: dict[str, ActorReminderData]
7276

7377
async def add_state(self, state_name: str, value: T) -> None:
7478
if not await self.try_add_state(state_name, value):

0 commit comments

Comments
 (0)