@@ -215,6 +215,26 @@ def print_summary(
215215
216216 table = rich .table .Table (* columns , width = line_length , show_lines = True )
217217
218+ def get_connections (layer ):
219+ connections = ""
220+ for node in layer ._inbound_nodes :
221+ if relevant_nodes and node not in relevant_nodes :
222+ # node is not part of the current network
223+ continue
224+ for kt in node .input_tensors :
225+ keras_history = kt ._keras_history
226+ inbound_layer = keras_history .operation
227+ node_index = highlight_number (keras_history .node_index )
228+ tensor_index = highlight_number (keras_history .tensor_index )
229+ if connections :
230+ connections += ", "
231+ connections += (
232+ f"{ inbound_layer .name } [{ node_index } ][{ tensor_index } ]"
233+ )
234+ if not connections :
235+ connections = "-"
236+ return connections
237+
218238 def get_layer_fields (layer , prefix = "" ):
219239 output_shape = format_layer_shape (layer )
220240 name = prefix + layer .name
@@ -230,6 +250,8 @@ def get_layer_fields(layer, prefix=""):
230250 params = highlight_number (f"{ layer .count_params ():,} " )
231251
232252 fields = [name , output_shape , params ]
253+ if not sequential_like :
254+ fields .append (get_connections (layer ))
233255 if show_trainable :
234256 fields .append (
235257 bold_text ("Y" , color = 34 )
@@ -238,37 +260,13 @@ def get_layer_fields(layer, prefix=""):
238260 )
239261 return fields
240262
241- def get_connections (layer ):
242- connections = ""
243- for node in layer ._inbound_nodes :
244- if relevant_nodes and node not in relevant_nodes :
245- # node is not part of the current network
246- continue
247- for kt in node .input_tensors :
248- keras_history = kt ._keras_history
249- inbound_layer = keras_history .operation
250- node_index = highlight_number (keras_history .node_index )
251- tensor_index = highlight_number (keras_history .tensor_index )
252- if connections :
253- connections += ", "
254- connections += (
255- f"{ inbound_layer .name } [{ node_index } ][{ tensor_index } ]"
256- )
257- if not connections :
258- connections = "-"
259- return connections
260-
261263 def print_layer (layer , nested_level = 0 ):
262264 if nested_level :
263265 prefix = " " * nested_level + "└" + " "
264266 else :
265267 prefix = ""
266268
267269 fields = get_layer_fields (layer , prefix = prefix )
268- if not sequential_like :
269- fields .append (get_connections (layer ))
270- if show_trainable :
271- fields .append ("Y" if layer .trainable else "N" )
272270
273271 rows = [fields ]
274272 if expand_nested and hasattr (layer , "layers" ) and layer .layers :
0 commit comments