Skip to content

Commit 05ec3ed

Browse files
committed
test(experimental): add tests for generate_base_model and trigger decorators
1 parent 6f6df0f commit 05ec3ed

File tree

7 files changed

+182
-35
lines changed

7 files changed

+182
-35
lines changed

tests/test_experimental.py

+170-23
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,193 @@
1+
from typing import TYPE_CHECKING
12
from unittest import TestCase
3+
from types import ModuleType
24
from unittest.mock import MagicMock
35

46
from transitions import Machine
5-
from transitions.experimental.decoration import expect_override
7+
from transitions.experimental.typing import generate_base_model
8+
from transitions.experimental.decoration import trigger, with_trigger_decorator
9+
from transitions.extensions import HierarchicalMachine
10+
11+
from .utils import Stuff
12+
13+
if TYPE_CHECKING:
14+
from transitions.core import MachineConfig
15+
from typing import Type
16+
17+
18+
def import_code(code: str, name: str) -> ModuleType:
19+
module = ModuleType(name)
20+
exec(code, module.__dict__)
21+
return module
622

723

824
class TestExperimental(TestCase):
925

1026
def setUp(self) -> None:
11-
self.machine_cls = Machine
12-
return super().setUp()
27+
self.machine_cls = Machine # type: Type[Machine]
28+
self.create_trigger_class()
1329

14-
def test_override_decorator(self):
15-
b_mock = MagicMock()
16-
c_mock = MagicMock()
30+
def create_trigger_class(self):
31+
@with_trigger_decorator
32+
class TriggerMachine(self.machine_cls): # type: ignore
33+
pass
34+
35+
self.trigger_machine = TriggerMachine
36+
37+
def test_model_override(self):
1738

1839
class Model:
1940

20-
@expect_override
21-
def is_A(self) -> bool:
41+
def trigger(self, name: str) -> bool:
2242
raise RuntimeError("Should be overridden")
2343

24-
def is_B(self) -> bool:
25-
b_mock()
26-
return False
44+
def is_A(self) -> bool:
45+
raise RuntimeError("Should be overridden")
2746

28-
@expect_override
2947
def is_C(self) -> bool:
30-
c_mock()
31-
return False
48+
raise RuntimeError("Should be overridden")
3249

3350
model = Model()
34-
machine = self.machine_cls(model, states=["A", "B"], initial="A")
51+
machine = self.machine_cls(model, states=["A", "B"], initial="A", model_override=True)
3552
self.assertTrue(model.is_A())
36-
self.assertTrue(model.to_B())
37-
self.assertFalse(model.is_B()) # not overridden with convenience function
38-
self.assertTrue(b_mock.called)
39-
self.assertFalse(model.is_C()) # not overridden yet
40-
self.assertTrue(c_mock.called)
53+
with self.assertRaises(AttributeError):
54+
model.to_B() # type: ignore # Should not be assigned to model since its not declared
55+
self.assertTrue(model.trigger("to_B"))
56+
self.assertFalse(model.is_A())
57+
with self.assertRaises(RuntimeError):
58+
model.is_C() # not overridden yet
4159
machine.add_state("C")
4260
self.assertFalse(model.is_C()) # now it is!
43-
self.assertEqual(1, c_mock.call_count) # call_count is not increased
44-
self.assertTrue(model.to_C())
61+
self.assertTrue(model.trigger("to_C"))
4562
self.assertTrue(model.is_C())
46-
self.assertEqual(1, c_mock.call_count)
63+
64+
def test_generate_base_model(self):
65+
simple_config = {
66+
"states": ["A", "B"],
67+
"transitions": [
68+
["go", "A", "B"],
69+
["back", "*", "A"]
70+
],
71+
"initial": "A",
72+
"model_override": True
73+
} # type: MachineConfig
74+
75+
mod = import_code(generate_base_model(simple_config), "base_module")
76+
model = mod.BaseModel()
77+
machine = self.machine_cls(model, **simple_config)
78+
self.assertTrue(model.is_A())
79+
self.assertTrue(model.go())
80+
self.assertTrue(model.is_B())
81+
self.assertTrue(model.back())
82+
self.assertTrue(model.state == "A")
83+
with self.assertRaises(AttributeError):
84+
model.is_C()
85+
86+
def test_generate_base_model_callbacks(self):
87+
simple_config = {
88+
"states": ["A", "B"],
89+
"transitions": [
90+
["go", "A", "B"],
91+
],
92+
"initial": "A",
93+
"model_override": True,
94+
"before_state_change": "call_this"
95+
} # type: MachineConfig
96+
97+
mod = import_code(generate_base_model(simple_config), "base_module")
98+
mock = MagicMock()
99+
100+
class Model(mod.BaseModel): # type: ignore
101+
102+
@staticmethod
103+
def call_this() -> None:
104+
mock()
105+
106+
model = Model()
107+
machine = self.machine_cls(model, **simple_config)
108+
self.assertTrue(model.is_A())
109+
self.assertTrue(model.go())
110+
self.assertTrue(mock.called)
111+
112+
def test_generate_model_no_auto(self):
113+
simple_config: MachineConfig = {
114+
"states": ["A", "B"],
115+
"auto_transitions": False,
116+
"model_override": True,
117+
"transitions": [
118+
["go", "A", "B"],
119+
["back", "*", "A"]
120+
],
121+
"initial": "A"
122+
}
123+
mod = import_code(generate_base_model(simple_config), "base_module")
124+
model = mod.BaseModel()
125+
machine = self.machine_cls(model, **simple_config)
126+
self.assertTrue(model.is_A())
127+
self.assertTrue(model.go())
128+
with self.assertRaises(AttributeError):
129+
model.to_B()
130+
131+
def test_trigger_decorator(self):
132+
133+
class Model:
134+
135+
state: str = ""
136+
137+
def is_B(self) -> bool:
138+
return False
139+
140+
@trigger(source="A", dest="B")
141+
@trigger(source=["A", "B"], dest="C")
142+
def go(self) -> bool:
143+
raise RuntimeError("Should be overridden!")
144+
145+
model = Model()
146+
machine = self.trigger_machine(model, states=["A", "B", "C"], initial="A")
147+
self.assertEqual("A", model.state)
148+
self.assertTrue(machine.is_state("A", model))
149+
self.assertTrue(model.go())
150+
with self.assertRaises(AttributeError):
151+
model.is_A() # type: ignore
152+
self.assertEqual("B", model.state)
153+
self.assertTrue(model.is_B())
154+
self.assertTrue(model.go())
155+
self.assertFalse(model.is_B())
156+
self.assertEqual("C", model.state)
157+
158+
def test_trigger_decorator_complex(self):
159+
160+
class Model:
161+
162+
state: str = ""
163+
164+
def check_param(self, param: bool) -> bool:
165+
return param
166+
167+
@trigger(source="A", dest="B")
168+
@trigger(source="B", dest="C", unless=Stuff.this_passes)
169+
@trigger(source="B", dest="A", conditions=Stuff.this_passes, unless=Stuff.this_fails)
170+
def go(self) -> bool:
171+
raise RuntimeError("Should be overridden")
172+
173+
@trigger(source="A", dest="B", conditions="check_param")
174+
def event(self, param) -> bool:
175+
raise RuntimeError("Should be overridden")
176+
177+
model = Model()
178+
machine = self.trigger_machine(model, states=["A", "B"], initial="A")
179+
self.assertTrue(model.go())
180+
self.assertTrue(model.state == "B")
181+
self.assertTrue(model.go())
182+
self.assertTrue(model.state == "A")
183+
self.assertFalse(model.event(param=False))
184+
self.assertTrue(model.state == "A")
185+
self.assertTrue(model.event(param=True))
186+
self.assertTrue(model.state == "B")
187+
188+
189+
class TestHSMExperimental(TestExperimental):
190+
191+
def setUp(self):
192+
self.machine_cls = HierarchicalMachine # type: Type[HierarchicalMachine]
193+
self.create_trigger_class()

tests/test_graphviz.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def test_function_callbacks_annotation(self):
314314
class TestDiagramsNested(TestDiagrams):
315315

316316
machine_cls = HierarchicalGraphMachine \
317-
# type: Type[HierarchicalGraphMachine | LockedHierarchicalGraphMachine]
317+
# type: Type[Union[HierarchicalGraphMachine, LockedHierarchicalGraphMachine]]
318318

319319
def setUp(self):
320320
super(TestDiagramsNested, self).setUp()

transitions/core.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ ModelParameter = Union[Union[Literal['self'], Any], List[Union[Literal['self'],
1818

1919

2020
class MachineConfig(TypedDict, total=False):
21-
states: list[StateIdentifier]
22-
transitions: list[TransitionConfig]
21+
states: List[StateIdentifier]
22+
transitions: List[TransitionConfig]
2323
initial: str
2424
auto_transitions: bool
2525
send_event: bool
+3-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from typing import Union, Type, Callable
1+
from typing import Union, Type, Callable, List, Optional
22
from transitions.core import StateIdentifier, CallbacksArg, CallbackFunc, Machine
33

44

55
def with_trigger_decorator(cls: Type[Machine]) -> Type[Machine]: ...
66

7-
def trigger(source: Union[StateIdentifier, list[StateIdentifier]],
8-
dest: StateIdentifier | None = ...,
7+
def trigger(source: Union[StateIdentifier, List[StateIdentifier]],
8+
dest: Optional[StateIdentifier] = ...,
99
conditions: CallbacksArg = ..., unless: CallbacksArg = ...,
1010
before: CallbacksArg = ..., after: CallbacksArg = ...,
1111
prepare: CallbacksArg = ...) -> Callable[[CallbackFunc], CallbackFunc]: ...

transitions/experimental/typing.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,15 @@ def generate_base_model(config):
6262
f" def may_{trigger_name}(self) -> bool: {_placeholder_body}\n"
6363
)
6464

65-
extra_params = "event_data: EventData" if m.send_event else "*args: list[Any], **kwargs: dict[str, Any]"
65+
extra_params = "event_data: EventData" if m.send_event else "*args: List[Any], **kwargs: Dict[str, Any]"
6666
for callback_name in callbacks:
6767
if isinstance(callback_name, str):
6868
callback_block += (f" @abstractmethod\n"
69-
f" def {callback_name}(self, {extra_params}) -> bool | None: ...\n")
69+
f" def {callback_name}(self, {extra_params}) -> Optional[bool]: ...\n")
7070

7171
template = f"""# autogenerated by transitions
7272
from abc import ABCMeta, abstractmethod
73-
from typing import Callable, ParamSpec, Union, Optional, Tuple, TYPE_CHECKING, Any
73+
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
7474
7575
if TYPE_CHECKING:
7676
from transitions.core import CallbacksArg, StateIdentifier, EventData

transitions/experimental/typing.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Any
1+
from typing import Union
22
from transitions.core import MachineConfig
33
from transitions.extensions.markup import MarkupConfig
44

55
_placeholder_body: str
66

7-
def generate_base_model(config: MachineConfig | MarkupConfig) -> str: ...
7+
def generate_base_model(config: Union[MachineConfig, MarkupConfig]) -> str: ...

transitions/extensions/markup.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class MarkupConfig(TypedDict):
2121
send_event: bool
2222
auto_transitions: bool
2323
ignore_invalid_triggers: bool
24-
queued: bool | str
24+
queued: Union[bool, str]
2525

2626

2727
class MarkupMachine(Machine):

0 commit comments

Comments
 (0)