Skip to content

Commit cac5773

Browse files
committed
improve graph generation of novel hsm
- for both graphviz and pygraphviz use dest if state is an enum improve graph generation for nested graphs (#413)
1 parent 9237b9e commit cac5773

File tree

4 files changed

+98
-93
lines changed

4 files changed

+98
-93
lines changed

Diff for: transitions/extensions/diagrams.py

+55-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
import logging
77
from functools import partial
8+
import copy
89

910
_LOGGER = logging.getLogger(__name__)
1011
_LOGGER.addHandler(logging.NullHandler())
@@ -26,8 +27,10 @@ class TransitionGraphSupport(Transition):
2627
def _change_state(self, event_data):
2728
graph = event_data.machine.model_graphs[event_data.model]
2829
graph.reset_styling()
29-
graph.set_previous_transition(self.source, self.dest)
30+
graph.set_previous_transition(self.source, self.dest, event_data.event.name)
3031
_super(TransitionGraphSupport, self)._change_state(event_data) # pylint: disable=protected-access
32+
for state in _flatten(listify(getattr(event_data.model, event_data.machine.model_attribute))):
33+
graph.set_node_style(self.dest if hasattr(state, 'name') else state, 'active')
3134

3235

3336
class GraphMachine(MarkupMachine):
@@ -88,7 +91,13 @@ class GraphMachine(MarkupMachine):
8891
'': {},
8992
'default': {
9093
'color': 'black',
91-
'fillcolor': 'white'
94+
'fillcolor': 'white',
95+
'style': 'solid'
96+
},
97+
'parallel': {
98+
'color': 'black',
99+
'fillcolor': 'white',
100+
'style': 'dotted'
92101
},
93102
'previous': {
94103
'color': 'blue',
@@ -263,3 +272,47 @@ def _get_global_name(self, path):
263272
return self._get_global_name(path)
264273
else:
265274
return self.machine.get_global_name()
275+
276+
def _get_elements(self):
277+
states = []
278+
transitions = []
279+
try:
280+
markup = self.machine.get_markup_config()
281+
q = [([], markup)]
282+
283+
while q:
284+
prefix, scope = q.pop(0)
285+
for transition in scope.get('transitions', []):
286+
if prefix:
287+
t = copy.copy(transition)
288+
t['source'] = self.machine.state_cls.separator.join(prefix + [t['source']])
289+
t['dest'] = self.machine.state_cls.separator.join(prefix + [t['dest']])
290+
else:
291+
t = transition
292+
transitions.append(t)
293+
for state in scope.get('children', []) + scope.get('states', []):
294+
if not prefix:
295+
s = state
296+
states.append(s)
297+
298+
ini = state.get('initial', [])
299+
if not isinstance(ini, list):
300+
ini = ini.name if hasattr(ini, 'name') else ini
301+
t = dict(trigger='',
302+
source=self.machine.state_cls.separator.join(prefix + [state['name']]) + '_anchor',
303+
dest=self.machine.state_cls.separator.join(prefix + [state['name'], ini]))
304+
transitions.append(t)
305+
if state.get('children', []):
306+
q.append((prefix + [state['name']], state))
307+
except KeyError as e:
308+
_LOGGER.error("Graph creation incomplete!")
309+
return states, transitions
310+
311+
312+
def _flatten(item):
313+
for elem in item:
314+
if isinstance(elem, (list, tuple, set)):
315+
for res in _flatten(elem):
316+
yield res
317+
else:
318+
yield elem

Diff for: transitions/extensions/diagrams_graphviz.py

+28-50
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import copy
1414

1515
from .diagrams import BaseGraph
16+
from ..core import listify
1617
try:
1718
import graphviz as pgv
1819
except ImportError: # pragma: no cover
@@ -37,10 +38,9 @@ def __init__(self, machine, title=None):
3738
self.reset_styling()
3839
_super(Graph, self).__init__(machine, title)
3940

40-
def set_previous_transition(self, src, dst):
41+
def set_previous_transition(self, src, dst, key=None):
4142
self.custom_styles['edge'][src][dst] = 'previous'
4243
self.set_node_style(src, 'previous')
43-
self.set_node_style(dst, 'active')
4444

4545
def set_node_style(self, state, style):
4646
self.custom_styles['node'][state] = style
@@ -76,49 +76,24 @@ def generate(self, title=None, roi_state=None):
7676
if not pgv: # pragma: no cover
7777
raise Exception('AGraph diagram requires graphviz')
7878

79-
title = '' if not title else title
79+
title = self.machine.title if not title else title
8080

8181
fsm_graph = pgv.Digraph(name=title, node_attr=self.machine.style_attributes['node']['default'],
8282
edge_attr=self.machine.style_attributes['edge']['default'],
8383
graph_attr=self.machine.style_attributes['graph']['default'])
8484
fsm_graph.graph_attr.update(**self.machine.machine_attributes)
85+
fsm_graph.graph_attr['label'] = title
8586
# For each state, draw a circle
86-
try:
87-
markup = self.machine.get_markup_config()
88-
q = [([], markup)]
89-
states = []
90-
transitions = []
91-
while q:
92-
prefix, scope = q.pop(0)
93-
for transition in scope.get('transitions', []):
94-
if prefix:
95-
t = copy.copy(transition)
96-
t['source'] = self.machine.state_cls.separator.join(prefix + [t['source']])
97-
t['dest'] = self.machine.state_cls.separator.join(prefix + [t['dest']])
98-
else:
99-
t = transition
100-
transitions.append(t)
101-
for state in scope.get('states', []):
102-
if prefix:
103-
s = copy.copy(state)
104-
s['name'] = self.machine.state_cls.separator.join(prefix + [s['name']])
105-
else:
106-
s = state
107-
states.append(s)
108-
if state.get('children', []):
109-
q.append((prefix + [state['name']], state))
110-
111-
if roi_state:
112-
transitions = [t for t in transitions
113-
if t['source'] == roi_state or self.custom_styles['edge'][t['source']][t['dest']]]
114-
state_names = [t for trans in transitions
115-
for t in [trans['source'], trans.get('dest', trans['source'])]]
116-
state_names += [k for k, style in self.custom_styles['node'].items() if style]
117-
states = _filter_states(states, state_names, self.machine.state_cls)
118-
self._add_nodes(states, fsm_graph)
119-
self._add_edges(transitions, fsm_graph)
120-
except KeyError:
121-
_LOGGER.error("Graph creation incomplete!")
87+
states, transitions = self._get_elements()
88+
if roi_state:
89+
transitions = [t for t in transitions
90+
if t['source'] == roi_state or self.custom_styles['edge'][t['source']][t['dest']]]
91+
state_names = [t for trans in transitions
92+
for t in [trans['source'], trans.get('dest', trans['source'])]]
93+
state_names += [k for k, style in self.custom_styles['node'].items() if style]
94+
states = _filter_states(states, state_names, self.machine.state_cls)
95+
self._add_nodes(states, fsm_graph)
96+
self._add_edges(transitions, fsm_graph)
12297
setattr(fsm_graph, 'draw', partial(self.draw, fsm_graph))
12398
return fsm_graph
12499

@@ -152,41 +127,44 @@ def __init__(self, *args, **kwargs):
152127
self._cluster_states = []
153128
_super(NestedGraph, self).__init__(*args, **kwargs)
154129

155-
def set_previous_transition(self, src, dst):
130+
def set_previous_transition(self, src, dst, key=None):
156131
src_name = self._get_global_name(src.split(self.machine.state_cls.separator))
157132
dst_name = self._get_global_name(dst.split(self.machine.state_cls.separator))
158-
_super(NestedGraph, self).set_previous_transition(src_name, dst_name)
133+
_super(NestedGraph, self).set_previous_transition(src_name, dst_name, key)
159134

160-
def _add_nodes(self, states, container, prefix=''):
135+
def _add_nodes(self, states, container, prefix='', default_style='default'):
161136

162137
for state in states:
163138
name = prefix + state['name']
164139
label = self._convert_state_attributes(state)
165140

166-
if 'children' in state:
141+
if state.get('children', []):
167142
cluster_name = "cluster_" + name
168143
with container.subgraph(name=cluster_name,
169144
graph_attr=self.machine.style_attributes['graph']['default']) as sub:
170-
style = self.custom_styles['node'][name]
145+
style = self.custom_styles['node'][name] or default_style
171146
sub.graph_attr.update(label=label, rank='source', **self.machine.style_attributes['graph'][style])
172147
self._cluster_states.append(name)
148+
is_parallel = isinstance(state.get('initial', ''), list)
149+
width = '0.0' if is_parallel else '0.1'
173150
with sub.subgraph(name=cluster_name + '_root',
174151
graph_attr={'label': '', 'color': 'None', 'rank': 'min'}) as root:
175-
root.node(name + "_anchor", shape='point', fillcolor='black', width='0.1')
176-
self._add_nodes(state['children'], sub, prefix=prefix + state['name'] + self.machine.state_cls.separator)
152+
root.node(name + "_anchor", shape='point', fillcolor='black', width=width)
153+
self._add_nodes(state['children'], sub, default_style='parallel' if is_parallel else 'default',
154+
prefix=prefix + state['name'] + self.machine.state_cls.separator)
177155
else:
178-
style = self.custom_styles['node'][name]
156+
style = self.custom_styles['node'][name] or default_style
179157
container.node(name, label=label, **self.machine.style_attributes['node'][style])
180158

181-
def _add_edges(self, transitions, container):
159+
def _add_edges(self, transitions, container, prefix=''):
182160
edges_attr = defaultdict(lambda: defaultdict(dict))
183161

184162
for transition in transitions:
185163
# enable customizable labels
186164
label_pos = 'label'
187-
src = transition['source']
165+
src = prefix + transition['source']
188166
try:
189-
dst = transition['dest']
167+
dst = prefix + transition['dest']
190168
except KeyError:
191169
dst = src
192170
if edges_attr[src][dst]:

Diff for: transitions/extensions/diagrams_pygraphviz.py

+14-40
Original file line numberDiff line numberDiff line change
@@ -61,35 +61,9 @@ def generate(self, title=None):
6161
self.fsm_graph = pgv.AGraph(label=title, **self.machine.machine_attributes)
6262
self.fsm_graph.node_attr.update(self.machine.style_attributes['node']['default'])
6363
self.fsm_graph.edge_attr.update(self.machine.style_attributes['edge']['default'])
64-
65-
# For each state, draw a circle
66-
markup = self.machine.get_markup_config()
67-
q = [([], markup)]
68-
states = []
69-
transitions = []
70-
while q:
71-
prefix, scope = q.pop(0)
72-
for transition in scope.get('transitions', []):
73-
if prefix:
74-
t = copy.copy(transition)
75-
t['source'] = self.machine.state_cls.separator.join(prefix + [t['source']])
76-
t['dest'] = self.machine.state_cls.separator.join(prefix + [t['dest']])
77-
else:
78-
t = transition
79-
transitions.append(t)
80-
for state in scope.get('states', []):
81-
if prefix:
82-
s = copy.copy(state)
83-
s['name'] = self.machine.state_cls.separator.join(prefix + [s['name']])
84-
else:
85-
s = state
86-
states.append(s)
87-
if state.get('children', []):
88-
q.append((prefix + [state['name']], state))
89-
64+
states, transitions = self._get_elements()
9065
self._add_nodes(states, self.fsm_graph)
9166
self._add_edges(transitions, self.fsm_graph)
92-
9367
setattr(self.fsm_graph, 'style_attributes', self.machine.style_attributes)
9468

9569
return self.fsm_graph
@@ -132,7 +106,7 @@ def set_node_style(self, state, style):
132106
style_attr = self.fsm_graph.style_attributes.get('node', {}).get(style)
133107
node.attr.update(style_attr)
134108

135-
def set_previous_transition(self, src, dst):
109+
def set_previous_transition(self, src, dst, key=None):
136110
try:
137111
edge = self.fsm_graph.get_edge(src, dst)
138112
except KeyError:
@@ -164,27 +138,26 @@ def __init__(self, *args, **kwargs):
164138
_super(NestedGraph, self).__init__(*args, **kwargs)
165139
# self.style_attributes['edge']['default']['minlen'] = 2
166140

167-
def _add_nodes(self, states, container, prefix=''):
141+
def _add_nodes(self, states, container, prefix='', default_style='default'):
168142
for state in states:
169143
name = prefix + state['name']
170144
label = self._convert_state_attributes(state)
171145

172146
if 'children' in state:
173147
cluster_name = "cluster_" + name
148+
is_parallel = isinstance(state.get('initial', ''), list)
174149
sub = container.add_subgraph(name=cluster_name, label=label, rank='source',
175-
**self.machine.style_attributes['graph']['default'])
150+
**self.machine.style_attributes['graph'][default_style])
176151
root_container = sub.add_subgraph(name=cluster_name + '_root', label='', color=None, rank='min')
177-
# child_container = sub.add_subgraph(name=cluster_name + '_child', label='', color=None)
178-
root_container.add_node(name + "_anchor", shape='point', fillcolor='black', width='0.1')
179-
self._add_nodes(state['children'], sub, prefix=prefix + state['name'] + NestedState.separator)
152+
width = '0' if is_parallel else '0.1'
153+
root_container.add_node(name + "_anchor", shape='point', fillcolor='black', width=width)
154+
self._add_nodes(state['children'], sub, prefix=prefix + state['name'] + NestedState.separator,
155+
default_style='parallel' if is_parallel else 'default')
180156
else:
181157
container.add_node(name, label=label, shape=self.machine.style_attributes['node']['default']['shape'])
182158

183159
def _add_edges(self, transitions, container):
184160

185-
# for sub in container.subgraphs_iter():
186-
# events = self._add_edges(transitions, sub)
187-
188161
for transition in transitions:
189162
# enable customizable labels
190163
label_pos = 'label'
@@ -232,28 +205,29 @@ def set_node_style(self, state, style):
232205
style_attr = self.fsm_graph.style_attributes.get('graph', {}).get(style)
233206
subgraph.graph_attr.update(style_attr)
234207

235-
def set_previous_transition(self, src, dst):
208+
def set_previous_transition(self, src, dst, key=None):
236209
src = self._get_global_name(src.split(self.machine.state_cls.separator))
237210
dst = self._get_global_name(dst.split(self.machine.state_cls.separator))
211+
edge_attr = self.fsm_graph.style_attributes.get('edge', {}).get('previous').copy()
238212
try:
239213
edge = self.fsm_graph.get_edge(src, dst)
240214
except KeyError:
241215
_src = src
242216
_dst = dst
243217
if _get_subgraph(self.fsm_graph, 'cluster_' + src):
218+
edge_attr['ltail'] = 'cluster_' + src
244219
_src += '_anchor'
245220
if _get_subgraph(self.fsm_graph, 'cluster_' + dst):
221+
edge_attr['lhead'] = "cluster_" + dst
246222
_dst += '_anchor'
247223
try:
248224
edge = self.fsm_graph.get_edge(_src, _dst)
249225
except KeyError:
250226
self.fsm_graph.add_edge(_src, _dst)
251227
edge = self.fsm_graph.get_edge(_src, _dst)
252228

253-
style_attr = self.fsm_graph.style_attributes.get('edge', {}).get('previous')
254-
edge.attr.update(style_attr)
229+
edge.attr.update(edge_attr)
255230
self.set_node_style(src, 'previous')
256-
self.set_node_style(dst, 'active')
257231

258232

259233
def _get_subgraph(graph, name):

Diff for: transitions/extensions/factory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -119,5 +119,5 @@ class HierarchicalAsyncGraphMachine(GraphMachine, HierarchicalAsyncMachine):
119119
(False, False, False, True): AsyncMachine,
120120
(True, False, False, True): AsyncGraphMachine,
121121
(False, True, False, True): HierarchicalAsyncMachine,
122-
(False, True, False, True): HierarchicalAsyncGraphMachine
122+
(True, True, False, True): HierarchicalAsyncGraphMachine
123123
}

0 commit comments

Comments
 (0)