@@ -170,13 +170,13 @@ def print_summary(
170170 break
171171
172172 if sequential_like :
173- line_length = line_length or 84
173+ default_line_length = 84
174174 positions = positions or [0.45 , 0.84 , 1.0 ]
175175 # header names for the different log elements
176176 header = ["Layer (type)" , "Output Shape" , "Param #" ]
177177 alignment = ["left" , "left" , "right" ]
178178 else :
179- line_length = line_length or 108
179+ default_line_length = 108
180180 positions = positions or [0.3 , 0.56 , 0.70 , 1.0 ]
181181 # header names for the different log elements
182182 header = ["Layer (type)" , "Output Shape" , "Param #" , "Connected to" ]
@@ -186,13 +186,16 @@ def print_summary(
186186 relevant_nodes += v
187187
188188 if show_trainable :
189- line_length += 8
189+ default_line_length += 8
190190 positions = [p * 0.88 for p in positions ] + [1.0 ]
191191 header .append ("Trainable" )
192192 alignment .append ("center" )
193193
194194 # Compute columns widths
195- line_length = min (line_length , shutil .get_terminal_size ().columns - 4 )
195+ default_line_length = min (
196+ default_line_length , shutil .get_terminal_size ().columns - 4
197+ )
198+ line_length = line_length or default_line_length
196199 column_widths = []
197200 current = 0
198201 for pos in positions :
0 commit comments