13
13
import copy
14
14
15
15
from .diagrams import BaseGraph
16
+ from ..core import listify
16
17
try :
17
18
import graphviz as pgv
18
19
except ImportError : # pragma: no cover
@@ -37,10 +38,9 @@ def __init__(self, machine, title=None):
37
38
self .reset_styling ()
38
39
_super (Graph , self ).__init__ (machine , title )
39
40
40
- def set_previous_transition (self , src , dst ):
41
+ def set_previous_transition (self , src , dst , key = None ):
41
42
self .custom_styles ['edge' ][src ][dst ] = 'previous'
42
43
self .set_node_style (src , 'previous' )
43
- self .set_node_style (dst , 'active' )
44
44
45
45
def set_node_style (self , state , style ):
46
46
self .custom_styles ['node' ][state ] = style
@@ -76,49 +76,24 @@ def generate(self, title=None, roi_state=None):
76
76
if not pgv : # pragma: no cover
77
77
raise Exception ('AGraph diagram requires graphviz' )
78
78
79
- title = '' if not title else title
79
+ title = self . machine . title if not title else title
80
80
81
81
fsm_graph = pgv .Digraph (name = title , node_attr = self .machine .style_attributes ['node' ]['default' ],
82
82
edge_attr = self .machine .style_attributes ['edge' ]['default' ],
83
83
graph_attr = self .machine .style_attributes ['graph' ]['default' ])
84
84
fsm_graph .graph_attr .update (** self .machine .machine_attributes )
85
+ fsm_graph .graph_attr ['label' ] = title
85
86
# For each state, draw a circle
86
- try :
87
- markup = self .machine .get_markup_config ()
88
- q = [([], markup )]
89
- states = []
90
- transitions = []
91
- while q :
92
- prefix , scope = q .pop (0 )
93
- for transition in scope .get ('transitions' , []):
94
- if prefix :
95
- t = copy .copy (transition )
96
- t ['source' ] = self .machine .state_cls .separator .join (prefix + [t ['source' ]])
97
- t ['dest' ] = self .machine .state_cls .separator .join (prefix + [t ['dest' ]])
98
- else :
99
- t = transition
100
- transitions .append (t )
101
- for state in scope .get ('states' , []):
102
- if prefix :
103
- s = copy .copy (state )
104
- s ['name' ] = self .machine .state_cls .separator .join (prefix + [s ['name' ]])
105
- else :
106
- s = state
107
- states .append (s )
108
- if state .get ('children' , []):
109
- q .append ((prefix + [state ['name' ]], state ))
110
-
111
- if roi_state :
112
- transitions = [t for t in transitions
113
- if t ['source' ] == roi_state or self .custom_styles ['edge' ][t ['source' ]][t ['dest' ]]]
114
- state_names = [t for trans in transitions
115
- for t in [trans ['source' ], trans .get ('dest' , trans ['source' ])]]
116
- state_names += [k for k , style in self .custom_styles ['node' ].items () if style ]
117
- states = _filter_states (states , state_names , self .machine .state_cls )
118
- self ._add_nodes (states , fsm_graph )
119
- self ._add_edges (transitions , fsm_graph )
120
- except KeyError :
121
- _LOGGER .error ("Graph creation incomplete!" )
87
+ states , transitions = self ._get_elements ()
88
+ if roi_state :
89
+ transitions = [t for t in transitions
90
+ if t ['source' ] == roi_state or self .custom_styles ['edge' ][t ['source' ]][t ['dest' ]]]
91
+ state_names = [t for trans in transitions
92
+ for t in [trans ['source' ], trans .get ('dest' , trans ['source' ])]]
93
+ state_names += [k for k , style in self .custom_styles ['node' ].items () if style ]
94
+ states = _filter_states (states , state_names , self .machine .state_cls )
95
+ self ._add_nodes (states , fsm_graph )
96
+ self ._add_edges (transitions , fsm_graph )
122
97
setattr (fsm_graph , 'draw' , partial (self .draw , fsm_graph ))
123
98
return fsm_graph
124
99
@@ -152,41 +127,44 @@ def __init__(self, *args, **kwargs):
152
127
self ._cluster_states = []
153
128
_super (NestedGraph , self ).__init__ (* args , ** kwargs )
154
129
155
- def set_previous_transition (self , src , dst ):
130
+ def set_previous_transition (self , src , dst , key = None ):
156
131
src_name = self ._get_global_name (src .split (self .machine .state_cls .separator ))
157
132
dst_name = self ._get_global_name (dst .split (self .machine .state_cls .separator ))
158
- _super (NestedGraph , self ).set_previous_transition (src_name , dst_name )
133
+ _super (NestedGraph , self ).set_previous_transition (src_name , dst_name , key )
159
134
160
- def _add_nodes (self , states , container , prefix = '' ):
135
+ def _add_nodes (self , states , container , prefix = '' , default_style = 'default' ):
161
136
162
137
for state in states :
163
138
name = prefix + state ['name' ]
164
139
label = self ._convert_state_attributes (state )
165
140
166
- if 'children' in state :
141
+ if state . get ( 'children' , []) :
167
142
cluster_name = "cluster_" + name
168
143
with container .subgraph (name = cluster_name ,
169
144
graph_attr = self .machine .style_attributes ['graph' ]['default' ]) as sub :
170
- style = self .custom_styles ['node' ][name ]
145
+ style = self .custom_styles ['node' ][name ] or default_style
171
146
sub .graph_attr .update (label = label , rank = 'source' , ** self .machine .style_attributes ['graph' ][style ])
172
147
self ._cluster_states .append (name )
148
+ is_parallel = isinstance (state .get ('initial' , '' ), list )
149
+ width = '0.0' if is_parallel else '0.1'
173
150
with sub .subgraph (name = cluster_name + '_root' ,
174
151
graph_attr = {'label' : '' , 'color' : 'None' , 'rank' : 'min' }) as root :
175
- root .node (name + "_anchor" , shape = 'point' , fillcolor = 'black' , width = '0.1' )
176
- self ._add_nodes (state ['children' ], sub , prefix = prefix + state ['name' ] + self .machine .state_cls .separator )
152
+ root .node (name + "_anchor" , shape = 'point' , fillcolor = 'black' , width = width )
153
+ self ._add_nodes (state ['children' ], sub , default_style = 'parallel' if is_parallel else 'default' ,
154
+ prefix = prefix + state ['name' ] + self .machine .state_cls .separator )
177
155
else :
178
- style = self .custom_styles ['node' ][name ]
156
+ style = self .custom_styles ['node' ][name ] or default_style
179
157
container .node (name , label = label , ** self .machine .style_attributes ['node' ][style ])
180
158
181
- def _add_edges (self , transitions , container ):
159
+ def _add_edges (self , transitions , container , prefix = '' ):
182
160
edges_attr = defaultdict (lambda : defaultdict (dict ))
183
161
184
162
for transition in transitions :
185
163
# enable customizable labels
186
164
label_pos = 'label'
187
- src = transition ['source' ]
165
+ src = prefix + transition ['source' ]
188
166
try :
189
- dst = transition ['dest' ]
167
+ dst = prefix + transition ['dest' ]
190
168
except KeyError :
191
169
dst = src
192
170
if edges_attr [src ][dst ]:
0 commit comments