@@ -1150,12 +1150,12 @@ def _processing_terms(self, term, previous_states_ind):
         else:
             container_ndim = self.container_ndim_all.get(term, 1)
             shape = input_shape(self.inputs[term], container_ndim=container_ndim)
-            var_ind = range(reduce(lambda x, y: x * y, shape))
+            var_ind = range(reduce(lambda x, y: x * y, shape, 1))
             new_keys = [term]
             # checking if the term is in inner_inputs
             if term in self.inner_inputs:
                 # TODO: have to be changed if differ length
-                inner_len = [shape[-1]] * reduce(lambda x, y: x * y, shape[:-1])
+                inner_len = [shape[-1]] * reduce(lambda x, y: x * y, shape[:-1], 1)
                 # this come from the previous node
                 outer_ind = self.inner_inputs[term].ind_l
                 var_ind_out = itertools.chain.from_iterable(
@@ -1172,7 +1172,7 @@ def _single_op_splits(self, op_single):
             self.inputs[op_single],
             container_ndim=self.container_ndim_all.get(op_single, 1),
         )
-        val_ind = range(reduce(lambda x, y: x * y, shape))
+        val_ind = range(reduce(lambda x, y: x * y, shape, 1))
         if op_single in self.inner_inputs:
            # TODO: have to be changed if differ length
            inner_len = [shape[-1]] * reduce(lambda x, y: x * y, shape[:-1], 1)
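For context, a minimal sketch of what the added `1` initializer guards against (the helper name `n_elements` and the example shapes are illustrative, not part of the diff): `functools.reduce` raises a `TypeError` when given an empty sequence and no initial value, so a product over an empty `shape` tuple, or over `shape[:-1]` when `shape` has a single dimension, would fail without it.

```python
from functools import reduce

def n_elements(shape):
    """Product of the dimensions; 1 for an empty shape tuple."""
    # Without the trailing 1, reduce(lambda x, y: x * y, ()) raises
    # TypeError: reduce() of empty iterable with no initial value
    return reduce(lambda x, y: x * y, shape, 1)

print(n_elements((2, 3)))     # 6 -> range(6) covers every element
print(n_elements((4,)))       # 4
print(n_elements(()))         # 1 -> an empty shape still yields one index
print(n_elements((4,)[:-1]))  # 1 -> shape[:-1] of a 1-D shape is empty
```

With the initializer, `range(reduce(lambda x, y: x * y, shape, 1))` and the `inner_len` products keep working when the shape tuple (or its `[:-1]` slice) is empty, instead of raising.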