Skip to content

Commit c54f806

Browse files
authored
Fix extra summary column when show_trainable=True (#435)
Silly bug, we were literally just adding the trainable field twice
1 parent e65d025 commit c54f806

File tree

1 file changed

+22
-24
lines changed

1 file changed

+22
-24
lines changed

keras_core/utils/summary_utils.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,26 @@ def print_summary(
215215

216216
table = rich.table.Table(*columns, width=line_length, show_lines=True)
217217

218+
def get_connections(layer):
219+
connections = ""
220+
for node in layer._inbound_nodes:
221+
if relevant_nodes and node not in relevant_nodes:
222+
# node is not part of the current network
223+
continue
224+
for kt in node.input_tensors:
225+
keras_history = kt._keras_history
226+
inbound_layer = keras_history.operation
227+
node_index = highlight_number(keras_history.node_index)
228+
tensor_index = highlight_number(keras_history.tensor_index)
229+
if connections:
230+
connections += ", "
231+
connections += (
232+
f"{inbound_layer.name}[{node_index}][{tensor_index}]"
233+
)
234+
if not connections:
235+
connections = "-"
236+
return connections
237+
218238
def get_layer_fields(layer, prefix=""):
219239
output_shape = format_layer_shape(layer)
220240
name = prefix + layer.name
@@ -230,6 +250,8 @@ def get_layer_fields(layer, prefix=""):
230250
params = highlight_number(f"{layer.count_params():,}")
231251

232252
fields = [name, output_shape, params]
253+
if not sequential_like:
254+
fields.append(get_connections(layer))
233255
if show_trainable:
234256
fields.append(
235257
bold_text("Y", color=34)
@@ -238,37 +260,13 @@ def get_layer_fields(layer, prefix=""):
238260
)
239261
return fields
240262

241-
def get_connections(layer):
242-
connections = ""
243-
for node in layer._inbound_nodes:
244-
if relevant_nodes and node not in relevant_nodes:
245-
# node is not part of the current network
246-
continue
247-
for kt in node.input_tensors:
248-
keras_history = kt._keras_history
249-
inbound_layer = keras_history.operation
250-
node_index = highlight_number(keras_history.node_index)
251-
tensor_index = highlight_number(keras_history.tensor_index)
252-
if connections:
253-
connections += ", "
254-
connections += (
255-
f"{inbound_layer.name}[{node_index}][{tensor_index}]"
256-
)
257-
if not connections:
258-
connections = "-"
259-
return connections
260-
261263
def print_layer(layer, nested_level=0):
262264
if nested_level:
263265
prefix = " " * nested_level + "└" + " "
264266
else:
265267
prefix = ""
266268

267269
fields = get_layer_fields(layer, prefix=prefix)
268-
if not sequential_like:
269-
fields.append(get_connections(layer))
270-
if show_trainable:
271-
fields.append("Y" if layer.trainable else "N")
272270

273271
rows = [fields]
274272
if expand_nested and hasattr(layer, "layers") and layer.layers:

0 commit comments

Comments
 (0)