Skip to content

Commit dbf6b5e

Browse files
committed
Enable None to be passed to bootstrapped fit
1 parent 0e435b5 commit dbf6b5e

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

econml/bootstrap.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,13 @@ def fit(self, *args, **named_args):
5757
def fit(x, *args, **kwargs):
5858
x.fit(*args, **kwargs)
5959
return x # Explicitly return x in case fit fails to return its target
60+
61+
def convertArg(arg, inds):
62+
return arg[inds] if arg is not None else None
6063
self._instances = Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=3)(
61-
delayed(fit)(obj, *[arg[inds] for arg in args], **{arg: named_args[arg][inds] for arg in named_args})
64+
delayed(fit)(obj,
65+
*[convertArg(arg, inds) for arg in args],
66+
**{arg: convertArg(named_args[arg], inds) for arg in named_args})
6267
for obj, inds in zip(self._instances, indices)
6368
)
6469
return self

0 commit comments

Comments
 (0)