Skip to content

Commit a033a4b

Browse files
committed
Use static shape in join_nonshared_inputs
1 parent e9e850c commit a033a4b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pymc/pytensorf.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -563,13 +563,13 @@ def join_nonshared_inputs(
563563
raise ValueError("Empty list of input variables.")
564564

565565
raveled_inputs = pt.concatenate([var.ravel() for var in inputs])
566+
size = sum(point[var_name].size for var_name in point)
566567

567568
if not make_inputs_shared:
568-
tensor_type = raveled_inputs.type
569-
joined_inputs = tensor_type("joined_inputs")
569+
joined_inputs = pt.tensor("joined_inputs", shape=(size,), dtype=raveled_inputs.dtype)
570570
else:
571571
joined_values = np.concatenate([point[var.name].ravel() for var in inputs])
572-
joined_inputs = pytensor.shared(joined_values, "joined_inputs")
572+
joined_inputs = pytensor.shared(joined_values, "joined_inputs", shape=(size,))
573573

574574
if pytensor.config.compute_test_value != "off":
575575
joined_inputs.tag.test_value = raveled_inputs.tag.test_value

0 commit comments

Comments
 (0)