Skip to content

Commit 4b38b9d

Browse files
committed
this fixes #292
mode model graphs to `GraphMachien.model_graphs`; add long_description in markdown to setup.py; remove python 3.3 (EOL) from travis; add test_pickle_model to core; move DummyModel/TestContext to utils; let TestRep inherit from TestCore; adapt test to not directly use model.graph;
1 parent 2c6298f commit 4b38b9d

11 files changed

+81
-55
lines changed

.travis.yml

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ python:
33
- "3.6"
44
- "3.5"
55
- "3.4"
6-
- "3.3"
76
- "2.7"
87
install:
98
- pip install python-coveralls coverage

Changelog.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## 0.6.5 ()
44

55
- Feature #287: Embedding `HierarchicalMachine` will now reuse the machine's `initial` state. Passing `initial: False` overrides this(thanks @mrjogo).
6+
- Bug #292: Models using `GraphMashine` were not picklable in the past due to `graph` property. Graphs for each model are now stored in `GraphMachine.model_graphs` (thanks @ansumanm)
67

78
## 0.6.4 (January, 2018)
89

setup.py

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
with open('transitions/version.py') as f:
55
exec(f.read())
66

7+
with open('README.md') as file:
8+
long_description = file.read()
9+
710
if len(set(('test', 'easy_install')).intersection(sys.argv)) > 0:
811
import setuptools
912

@@ -22,6 +25,7 @@
2225
name="transitions",
2326
version=__version__,
2427
description="A lightweight, object-oriented Python state machine implementation.",
28+
long_description=long_description,
2529
maintainer='Tal Yarkoni',
2630
maintainer_email='[email protected]',
2731
url='http://github.com/pytransitions/transitions',

test.py

Whitespace-only changes.

tests/test_core.py

+14
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,20 @@ def test_pickle(self):
485485
self.assertEqual(m.state, m2.state)
486486
m2.run()
487487

488+
def test_pickle_model(self):
489+
import sys
490+
if sys.version_info < (3, 4):
491+
import dill as pickle
492+
else:
493+
import pickle
494+
495+
self.stuff.to_B()
496+
dump = pickle.dumps(self.stuff)
497+
self.assertIsNotNone(dump)
498+
model2 = pickle.loads(dump)
499+
self.assertEqual(self.stuff.state, model2.state)
500+
model2.to_F()
501+
488502
def test_queued(self):
489503
states = ['A', 'B', 'C', 'D']
490504
# Define with list of dictionaries

tests/test_graphing.py

+16-14
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
pass
55

66
from .utils import Stuff
7+
from .test_core import TestTransitions
78

89
from transitions.extensions import MachineFactory
910
from transitions.extensions.diagrams import rep
1011
from transitions.extensions.nesting import NestedState
11-
from unittest import TestCase, skipIf
12+
from unittest import skipIf
1213
from functools import partial
1314
import tempfile
1415
import os
@@ -24,7 +25,7 @@ def edge_label_from_transition_label(label):
2425
return label.split(' | ')[0].split(' [')[0] # if no condition, label is returned; returns first event only
2526

2627

27-
class TestRep(TestCase):
28+
class TestRep(TestTransitions):
2829

2930
def test_rep_string(self):
3031
self.assertEqual(rep("string"), "string")
@@ -80,11 +81,12 @@ def __repr__(self):
8081

8182

8283
@skipIf(pgv is None, 'Graph diagram requires pygraphviz')
83-
class TestDiagrams(TestCase):
84+
class TestDiagrams(TestTransitions):
8485

85-
def setUp(self):
86-
self.machine_cls = MachineFactory.get_predefined(graph=True)
86+
machine_cls = MachineFactory.get_predefined(graph=True)
8787

88+
def setUp(self):
89+
self.stuff = Stuff(machine_cls=self.machine_cls)
8890
self.states = ['A', 'B', 'C', 'D']
8991
self.transitions = [
9092
{'trigger': 'walk', 'source': 'A', 'dest': 'B'},
@@ -154,12 +156,12 @@ def test_multi_model_state(self):
154156
m2 = Stuff(machine_cls=None)
155157
m = self.machine_cls(model=[m1, m2], states=self.states, transitions=self.transitions, initial='A')
156158
m1.walk()
157-
self.assertEqual(m1.graph.get_node(m1.state).attr['color'],
158-
m1.graph.style_attributes['node']['active']['color'])
159-
self.assertEqual(m2.graph.get_node(m1.state).attr['color'],
160-
m2.graph.style_attributes['node']['default']['color'])
159+
self.assertEqual(m1.get_graph().get_node(m1.state).attr['color'],
160+
m1.get_graph().style_attributes['node']['active']['color'])
161+
self.assertEqual(m2.get_graph().get_node(m1.state).attr['color'],
162+
m2.get_graph().style_attributes['node']['default']['color'])
161163
# backwards compatibility test
162-
self.assertEqual(m.get_graph(), m1.get_graph())
164+
self.assertEqual(id(m.get_graph()), id(m1.get_graph()))
163165

164166
def test_model_method_collision(self):
165167
class GraphModel:
@@ -202,16 +204,16 @@ def test_roi(self):
202204
@skipIf(pgv is None, 'Graph diagram requires pygraphviz')
203205
class TestDiagramsLocked(TestDiagrams):
204206

205-
def setUp(self):
206-
super(TestDiagramsLocked, self).setUp()
207-
self.machine_cls = MachineFactory.get_predefined(graph=True, locked=True)
207+
machine_cls = MachineFactory.get_predefined(graph=True, locked=True)
208208

209209

210210
@skipIf(pgv is None, 'NestedGraph diagram requires pygraphviz')
211211
class TestDiagramsNested(TestDiagrams):
212212

213+
machine_cls = MachineFactory.get_predefined(graph=True, nested=True)
214+
213215
def setUp(self):
214-
self.machine_cls = MachineFactory.get_predefined(graph=True, nested=True)
216+
super(TestDiagramsNested, self).setUp()
215217
self.states = ['A', 'B',
216218
{'name': 'C', 'children': [{'name': '1', 'children': ['a', 'b', 'c']},
217219
'2', '3']}, 'D']

tests/test_threading.py

+6-19
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from transitions.extensions.nesting import NestedState
1212
from .test_nesting import TestTransitions as TestsNested
1313
from .test_core import TestTransitions as TestCore
14-
from .utils import Stuff
14+
from .utils import Stuff, DummyModel, TestContext
1515

1616
try:
1717
from unittest.mock import MagicMock
@@ -166,28 +166,15 @@ def __exit__(self, *exc):
166166

167167
class TestMultipleContexts(TestCore):
168168

169-
class DummyModel(object):
170-
pass
171-
172-
class TestContext(object):
173-
def __init__(self, event_list):
174-
self._event_list = event_list
175-
176-
def __enter__(self):
177-
self._event_list.append((self, "enter"))
178-
179-
def __exit__(self, type, value, traceback):
180-
self._event_list.append((self, "exit"))
181-
182169
def setUp(self):
183170
self.event_list = []
184171

185-
self.s1 = self.DummyModel()
172+
self.s1 = DummyModel()
186173

187-
self.c1 = self.TestContext(event_list=self.event_list)
188-
self.c2 = self.TestContext(event_list=self.event_list)
189-
self.c3 = self.TestContext(event_list=self.event_list)
190-
self.c4 = self.TestContext(event_list=self.event_list)
174+
self.c1 = TestContext(event_list=self.event_list)
175+
self.c2 = TestContext(event_list=self.event_list)
176+
self.c3 = TestContext(event_list=self.event_list)
177+
self.c4 = TestContext(event_list=self.event_list)
191178

192179
self.stuff = Stuff(machine_cls=MachineFactory.get_predefined(locked=True), extra_kwargs={
193180
'machine_context': [self.c1, self.c2]

tests/utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,18 @@ def __init__(self, states, initial='A'):
9090
@staticmethod
9191
def this_passes():
9292
return True
93+
94+
95+
class DummyModel(object):
96+
pass
97+
98+
99+
class TestContext(object):
100+
def __init__(self, event_list):
101+
self._event_list = event_list
102+
103+
def __enter__(self):
104+
self._event_list.append((self, "enter"))
105+
106+
def __exit__(self, type, value, traceback):
107+
self._event_list.append((self, "exit"))

tox.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[tox]
2-
envlist = py2, py3, py33, py34, py35, codestyle, check-manifest
2+
envlist = py2, py3, py34, py35, py36, codestyle, check-manifest
33
skip_missing_interpreters = True
44

55
[testenv]

transitions/core.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,8 @@ def execute(self, event_data):
252252

253253
for cond in self.conditions:
254254
if not cond.check(event_data):
255-
_LOGGER.debug("%sTransition condition failed: %s() does not " +
256-
"return %s. Transition halted.", event_data.machine.name, cond.func, cond.target)
255+
_LOGGER.debug("%sTransition condition failed: %s() does not return %s. Transition halted.",
256+
event_data.machine.name, cond.func, cond.target)
257257
return False
258258
for func in itertools.chain(machine.before_state_change, self.before):
259259
machine.callback(func, event_data)

transitions/extensions/diagrams.py

+22-18
Original file line numberDiff line numberDiff line change
@@ -274,16 +274,16 @@ def _change_state(self, event_data):
274274
machine = event_data.machine
275275
model = event_data.model
276276
dest = machine.get_state(self.dest)
277+
graph = model.get_graph()
277278

278279
# Mark the active node
279-
machine.reset_graph_style(model.graph)
280+
machine.reset_graph_style(graph)
280281

281282
# Mark the previous node and path used
282283
if self.source is not None:
283284
source = machine.get_state(self.source)
284-
machine.set_node_state(model.graph, source.name,
285-
state='previous')
286-
machine.set_node_state(model.graph, dest.name, state='active')
285+
machine.set_node_state(graph, source.name, state='previous')
286+
machine.set_node_state(graph, dest.name, state='active')
287287

288288
if getattr(source, 'children', []):
289289
source = source.name + '_anchor'
@@ -293,8 +293,7 @@ def _change_state(self, event_data):
293293
dest = dest.name + '_anchor'
294294
else:
295295
dest = dest.name
296-
machine.set_edge_state(model.graph, source, dest,
297-
state='previous', label=event_data.event.name)
296+
machine.set_edge_state(graph, source, dest, state='previous', label=event_data.event.name)
298297

299298
_super(TransitionGraphSupport, self)._change_state(event_data) # pylint: disable=protected-access
300299

@@ -307,23 +306,29 @@ class GraphMachine(Machine):
307306
transition_cls (cls): TransitionGraphSupport
308307
"""
309308

310-
_pickle_blacklist = ['graph']
309+
_pickle_blacklist = ['model_graphs']
311310
transition_cls = TransitionGraphSupport
312311

312+
# model_graphs cannot be pickled. Omit them.
313313
def __getstate__(self):
314314
return {k: v for k, v in self.__dict__.items() if k not in self._pickle_blacklist}
315315

316316
def __setstate__(self, state):
317317
self.__dict__.update(state)
318+
self.model_graphs = {} # reinitialize new model_graphs
318319
for model in self.models:
319-
graph = self._get_graph(model, title=self.title)
320-
self.set_node_style(graph, model.state, 'active')
320+
try:
321+
graph = self._get_graph(model, title=self.title)
322+
self.set_node_style(graph, model.state, 'active')
323+
except AttributeError:
324+
_LOGGER.warning("Graph for model could not be initialized after pickling.")
321325

322326
def __init__(self, *args, **kwargs):
323327
# remove graph config from keywords
324328
self.title = kwargs.pop('title', 'State Machine')
325329
self.show_conditions = kwargs.pop('show_conditions', False)
326330
self.show_auto_transitions = kwargs.pop('show_auto_transitions', False)
331+
self.model_graphs = {}
327332

328333
# Update March 2017: This temporal overwrite does not work
329334
# well with inheritance. Since the tests pass I will disable it
@@ -346,7 +351,7 @@ def __init__(self, *args, **kwargs):
346351
raise AttributeError('Model already has a get_graph attribute. Graph retrieval cannot be bound.')
347352
setattr(model, 'get_graph', partial(self._get_graph, model))
348353
model.get_graph()
349-
self.set_node_state(model.graph, self.initial, 'active')
354+
self.set_node_state(self.model_graphs[model], self.initial, 'active')
350355

351356
# for backwards compatibility assign get_combined_graph to get_graph
352357
# if model is not the machine
@@ -356,12 +361,12 @@ def __init__(self, *args, **kwargs):
356361
def _get_graph(self, model, title=None, force_new=False, show_roi=False):
357362
if title is None:
358363
title = self.title
359-
if not hasattr(model, 'graph') or force_new:
360-
model.graph = NestedGraph(self).get_graph(title) if isinstance(self, HierarchicalMachine) \
364+
if model not in self.model_graphs or force_new:
365+
self.model_graphs[model] = NestedGraph(self).get_graph(title) if isinstance(self, HierarchicalMachine) \
361366
else Graph(self).get_graph(title)
362-
self.set_node_state(model.graph, model.state, state='active')
367+
self.set_node_state(self.model_graphs[model], model.state, state='active')
363368

364-
return model.graph if not show_roi else self._graph_roi(model)
369+
return self.model_graphs[model] if not show_roi else self._graph_roi(model)
365370

366371
def get_combined_graph(self, title=None, force_new=False, show_roi=False):
367372
""" This method is currently equivalent to 'get_graph' of the first machine's model.
@@ -374,7 +379,7 @@ def get_combined_graph(self, title=None, force_new=False, show_roi=False):
374379
the current state.
375380
Returns: AGraph of the first machine's model.
376381
"""
377-
_LOGGER.info('Returning graph of the first model. In future releases, this ' +
382+
_LOGGER.info('Returning graph of the first model. In future releases, this '
378383
'method will return a combined graph of all models.')
379384
return self._get_graph(self.models[0], title, force_new, show_roi)
380385

@@ -439,9 +444,8 @@ def set_node_state(self, graph, node_name, state='default'):
439444
func = self.set_graph_style
440445
func(graph, node, state)
441446

442-
@staticmethod
443-
def _graph_roi(model):
444-
graph = model.graph
447+
def _graph_roi(self, model):
448+
graph = model.get_graph()
445449
filtered = graph.copy()
446450

447451
kept_nodes = set()

0 commit comments

Comments
 (0)