Skip to content

Commit 307269a

Browse files
committed
fix batched_call docs
1 parent 4a98bed commit 307269a

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

bayesflow/utils/dictutils.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,19 @@ def convert_args(f, *args, **kwargs) -> tuple[any, ...]:
4747

4848
def batched_call(f: callable, batch_shape: Shape, *args: Tensor, **kwargs: Tensor):
4949
"""Call f, automatically vectorizing to batch_shape if required.
50-
f may accept any number of tensor or numpy array arguments.
51-
:param f:
52-
:param batch_shape:
53-
:param args:
54-
:param kwargs:
55-
:return:
50+
51+
:param f: The function to call.
52+
May accept any number of tensor or numpy array arguments.
53+
Must return a dictionary of tensors or numpy arrays.
54+
55+
:param batch_shape: The shape of the batch. If f is not already batched, it will be called
56+
prod(batch_shape) times.
57+
58+
:param args: Positional arguments to f
59+
60+
:param kwargs: Keyword arguments to f
61+
62+
:return: A dictionary of batched tensors or numpy arrays.
5663
"""
5764
try:
5865
# already batched

0 commit comments

Comments
 (0)