@@ -69,7 +69,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
6969 if not parent_nodes :
7070 # root node: generate independent samples
7171 node_samples = [
72- {"__batch_idx" : batch_idx , f"__{ node } _idx" : i } | self ._call_sampling_fn (sampling_fn , {})
72+ {"__batch_idx" : batch_idx , f"__{ node } _idx" : i } | self ._call_sample_fn (sampling_fn , {})
7373 for i in range (1 , reps + 1 )
7474 ]
7575 else :
@@ -86,7 +86,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
8686 [
8787 index_entries
8888 | {f"__{ node } _idx" : i }
89- | self ._call_sampling_fn (sampling_fn , sampling_fn_input )
89+ | self ._call_sample_fn (sampling_fn , sampling_fn_input )
9090 for i in range (1 , reps + 1 )
9191 ]
9292 )
@@ -169,12 +169,12 @@ def _output_shape(self, samples, variable):
169169
170170 return tuple (output_shape )
171171
172- def _call_sampling_fn (self , sampling_fn , args ):
173- signature = inspect .signature (sampling_fn )
172+ def _call_sample_fn (self , sample_fn , args ):
173+ signature = inspect .signature (sample_fn )
174174 fn_args = signature .parameters
175175 accepted_args = {k : v for k , v in args .items () if k in fn_args }
176176
177- return sampling_fn (** accepted_args )
177+ return sample_fn (** accepted_args )
178178
179179
180180def sorted_ancestors (graph , node ):
0 commit comments