@@ -274,16 +274,16 @@ def _change_state(self, event_data):
274
274
machine = event_data .machine
275
275
model = event_data .model
276
276
dest = machine .get_state (self .dest )
277
+ graph = model .get_graph ()
277
278
278
279
# Mark the active node
279
- machine .reset_graph_style (model . graph )
280
+ machine .reset_graph_style (graph )
280
281
281
282
# Mark the previous node and path used
282
283
if self .source is not None :
283
284
source = machine .get_state (self .source )
284
- machine .set_node_state (model .graph , source .name ,
285
- state = 'previous' )
286
- machine .set_node_state (model .graph , dest .name , state = 'active' )
285
+ machine .set_node_state (graph , source .name , state = 'previous' )
286
+ machine .set_node_state (graph , dest .name , state = 'active' )
287
287
288
288
if getattr (source , 'children' , []):
289
289
source = source .name + '_anchor'
@@ -293,8 +293,7 @@ def _change_state(self, event_data):
293
293
dest = dest .name + '_anchor'
294
294
else :
295
295
dest = dest .name
296
- machine .set_edge_state (model .graph , source , dest ,
297
- state = 'previous' , label = event_data .event .name )
296
+ machine .set_edge_state (graph , source , dest , state = 'previous' , label = event_data .event .name )
298
297
299
298
_super (TransitionGraphSupport , self )._change_state (event_data ) # pylint: disable=protected-access
300
299
@@ -307,23 +306,29 @@ class GraphMachine(Machine):
307
306
transition_cls (cls): TransitionGraphSupport
308
307
"""
309
308
310
- _pickle_blacklist = ['graph ' ]
309
+ _pickle_blacklist = ['model_graphs ' ]
311
310
transition_cls = TransitionGraphSupport
312
311
312
+ # model_graphs cannot be pickled. Omit them.
313
313
def __getstate__ (self ):
314
314
return {k : v for k , v in self .__dict__ .items () if k not in self ._pickle_blacklist }
315
315
316
316
def __setstate__ (self , state ):
317
317
self .__dict__ .update (state )
318
+ self .model_graphs = {} # reinitialize new model_graphs
318
319
for model in self .models :
319
- graph = self ._get_graph (model , title = self .title )
320
- self .set_node_style (graph , model .state , 'active' )
320
+ try :
321
+ graph = self ._get_graph (model , title = self .title )
322
+ self .set_node_style (graph , model .state , 'active' )
323
+ except AttributeError :
324
+ _LOGGER .warning ("Graph for model could not be initialized after pickling." )
321
325
322
326
def __init__ (self , * args , ** kwargs ):
323
327
# remove graph config from keywords
324
328
self .title = kwargs .pop ('title' , 'State Machine' )
325
329
self .show_conditions = kwargs .pop ('show_conditions' , False )
326
330
self .show_auto_transitions = kwargs .pop ('show_auto_transitions' , False )
331
+ self .model_graphs = {}
327
332
328
333
# Update March 2017: This temporal overwrite does not work
329
334
# well with inheritance. Since the tests pass I will disable it
@@ -346,7 +351,7 @@ def __init__(self, *args, **kwargs):
346
351
raise AttributeError ('Model already has a get_graph attribute. Graph retrieval cannot be bound.' )
347
352
setattr (model , 'get_graph' , partial (self ._get_graph , model ))
348
353
model .get_graph ()
349
- self .set_node_state (model . graph , self .initial , 'active' )
354
+ self .set_node_state (self . model_graphs [ model ] , self .initial , 'active' )
350
355
351
356
# for backwards compatibility assign get_combined_graph to get_graph
352
357
# if model is not the machine
@@ -356,12 +361,12 @@ def __init__(self, *args, **kwargs):
356
361
def _get_graph (self , model , title = None , force_new = False , show_roi = False ):
357
362
if title is None :
358
363
title = self .title
359
- if not hasattr ( model , 'graph' ) or force_new :
360
- model . graph = NestedGraph (self ).get_graph (title ) if isinstance (self , HierarchicalMachine ) \
364
+ if model not in self . model_graphs or force_new :
365
+ self . model_graphs [ model ] = NestedGraph (self ).get_graph (title ) if isinstance (self , HierarchicalMachine ) \
361
366
else Graph (self ).get_graph (title )
362
- self .set_node_state (model . graph , model .state , state = 'active' )
367
+ self .set_node_state (self . model_graphs [ model ] , model .state , state = 'active' )
363
368
364
- return model . graph if not show_roi else self ._graph_roi (model )
369
+ return self . model_graphs [ model ] if not show_roi else self ._graph_roi (model )
365
370
366
371
def get_combined_graph (self , title = None , force_new = False , show_roi = False ):
367
372
""" This method is currently equivalent to 'get_graph' of the first machine's model.
@@ -374,7 +379,7 @@ def get_combined_graph(self, title=None, force_new=False, show_roi=False):
374
379
the current state.
375
380
Returns: AGraph of the first machine's model.
376
381
"""
377
- _LOGGER .info ('Returning graph of the first model. In future releases, this ' +
382
+ _LOGGER .info ('Returning graph of the first model. In future releases, this '
378
383
'method will return a combined graph of all models.' )
379
384
return self ._get_graph (self .models [0 ], title , force_new , show_roi )
380
385
@@ -439,9 +444,8 @@ def set_node_state(self, graph, node_name, state='default'):
439
444
func = self .set_graph_style
440
445
func (graph , node , state )
441
446
442
- @staticmethod
443
- def _graph_roi (model ):
444
- graph = model .graph
447
+ def _graph_roi (self , model ):
448
+ graph = model .get_graph ()
445
449
filtered = graph .copy ()
446
450
447
451
kept_nodes = set ()
0 commit comments