Skip to content

Commit 13295ee

Browse files
authored
add replace nan adapter (#459)
* add replace nan adapter * improved naming * _mask as additional key * update test * improve * fix serializable * changed name to return_mask * add mask naming
1 parent 38186ec commit 13295ee

File tree

4 files changed

+145
-0
lines changed

4 files changed

+145
-0
lines changed

bayesflow/adapters/adapter.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Ungroup,
3232
RandomSubsample,
3333
Take,
34+
NanToNum,
3435
)
3536
from .transforms.filter_transform import Predicate
3637

@@ -956,3 +957,34 @@ def to_dict(self):
956957
transform = ToDict()
957958
self.transforms.append(transform)
958959
return self
960+
961+
def nan_to_num(
    self,
    keys: str | Sequence[str],
    default_value: float = 0.0,
    return_mask: bool = False,
    mask_prefix: str = "mask",
):
    """
    Append one :py:class:`~bf.adapters.transforms.NanToNum` transform per key to the adapter.

    Parameters
    ----------
    keys : str or sequence of str
        The names of the variables to clean / mask. A single string is treated as one key.
    default_value : float
        Value to substitute wherever data is NaN. Defaults to 0.0.
    return_mask : bool
        If True, encode a binary missingness mask alongside the data. Defaults to False.
    mask_prefix : str
        Prefix for the mask key in the output dictionary. Defaults to 'mask'; the mask for a
        variable ``key`` is stored under ``f"{mask_prefix}_{key}"``. When a mask is requested and
        that key already exists in the data, the transform raises a ValueError instead of
        overwriting it.

    Returns
    -------
    The adapter itself, so that calls can be chained.
    """
    # A bare string is a single key, not a sequence of one-character keys.
    key_list = [keys] if isinstance(keys, str) else keys

    for name in key_list:
        transform = NanToNum(
            key=name,
            default_value=default_value,
            return_mask=return_mask,
            mask_prefix=mask_prefix,
        )
        self.transforms.append(transform)

    return self

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from .random_subsample import RandomSubsample
3030
from .take import Take
3131
from .ungroup import Ungroup
32+
from .nan_to_num import NanToNum
3233

3334
from ...utils._docs import _add_imports_to_all
3435

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import numpy as np
2+
3+
from bayesflow.utils.serialization import serializable, serialize
4+
from .transform import Transform
5+
6+
7+
@serializable("bayesflow.adapters")
class NanToNum(Transform):
    """
    Replace NaNs with a default value, and optionally encode a missing-data mask as a separate output key.

    This is based on "Missing data in amortized simulation-based neural posterior estimation" by Wang et al. (2024).

    Parameters
    ----------
    key : str
        Name of the variable whose NaNs are replaced.
    default_value : float, default=0.0
        Value to substitute wherever data is NaN.
    return_mask : bool, default=False
        If True, a mask array will be returned under a new key.
    mask_prefix : str, default='mask'
        Prefix for the mask key in the output dictionary. The mask is stored under
        ``f"{mask_prefix}_{key}"``.
    """

    def __init__(self, key: str, default_value: float = 0.0, return_mask: bool = False, mask_prefix: str = "mask"):
        super().__init__()
        self.key = key
        self.default_value = default_value
        self.return_mask = return_mask
        self.mask_prefix = mask_prefix

    def get_config(self) -> dict:
        """Return the serialized constructor arguments of this transform."""
        return serialize(
            {
                "key": self.key,
                "default_value": self.default_value,
                "return_mask": self.return_mask,
                "mask_prefix": self.mask_prefix,
            }
        )

    @property
    def mask_key(self) -> str:
        """
        Key under which the mask will be stored in the output dictionary.
        """
        return f"{self.mask_prefix}_{self.key}"

    def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
        """
        Forward transform: fill NaNs and optionally output a mask under '<mask_prefix>_<key>'.

        The input dictionary and the array stored under ``key`` are left untouched; a new
        dictionary holding a filled copy is returned. Because the fill is delegated to
        ``np.nan_to_num``, +/-inf entries are also mapped to large finite numbers (numpy's
        default behavior); such entries are marked as valid in the mask.

        Raises
        ------
        ValueError
            If ``return_mask`` is True and the mask key already exists in ``data``.
        """
        data = data.copy()

        # A key clash only matters when a mask is actually written.
        if self.return_mask and self.mask_key in data:
            raise ValueError(
                f"Mask key '{self.mask_key}' already exists in the data. Please choose a different mask_prefix."
            )

        values = data[self.key]

        # Identify NaNs before they are replaced
        mask = np.isnan(values)

        # copy=True: the dictionary copy above is shallow, so filling in place would
        # silently overwrite the caller's array.
        data[self.key] = np.nan_to_num(values, copy=True, nan=self.default_value)

        if self.return_mask:
            # Mask convention: 1 for valid, 0 for NaN
            data[self.mask_key] = (~mask).astype(np.int8)

        return data

    def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
        """
        Inverse transform: restore NaNs.

        With ``return_mask=True`` the positions are read from the mask stored under
        '<mask_prefix>_<key>'. Otherwise every entry equal to ``default_value`` is turned
        back into NaN, which assumes that ``default_value`` does not occur in the real data.

        The array stored under ``key`` in the input is not modified. If NaNs have to be
        restored into a non-floating array (e.g. integers), the result is converted to
        float, since such dtypes cannot represent NaN.
        """
        data = data.copy()

        values = np.asarray(data[self.key])

        if self.return_mask:
            # Mask convention from forward: 1 for valid, 0 for NaN
            restore = ~np.asarray(data[self.mask_key]).astype(bool)
        else:
            # we assume default_value is not in data
            restore = values == self.default_value

        if not np.any(restore):
            # Nothing to restore: hand the values back unchanged.
            data[self.key] = values
            return data

        # Write into a private copy so the caller's array keeps its contents.
        target_dtype = values.dtype if np.issubdtype(values.dtype, np.inexact) else float
        restored = values.astype(target_dtype, copy=True)
        restored[restore] = np.nan

        data[self.key] = restored
        return data

tests/test_adapters/test_adapters.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,27 @@ def test_log_det_jac_exceptions(random_data):
298298
assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac)
299299

300300

301+
def test_nan_to_num():
    """Check ``Adapter.nan_to_num`` without a mask, with a mask, and on a two-dimensional array."""
    # Case 1: no mask requested, NaN is replaced by the chosen default value
    data = {"test": np.array([1.0, np.nan, 3.0])}
    adapter = bf.Adapter().nan_to_num(keys="test", default_value=-1.0, return_mask=False)
    filled = adapter.forward(data)["test"]
    np.testing.assert_array_equal(filled, np.array([1.0, -1.0, 3.0]))

    # Case 2: mask requested, stored under the default prefix ("mask_<key>")
    data = {
        "test": np.array([1.0, np.nan, 3.0]),
        "test-2d": np.array([[1.0, np.nan], [np.nan, 4.0]]),
    }
    adapter = bf.Adapter().nan_to_num(keys="test", default_value=0.0, return_mask=True)
    result = adapter.forward(data)
    expected_values = np.array([1.0, 0.0, 3.0])
    expected_mask = np.array([1.0, 0.0, 1.0])
    np.testing.assert_array_equal(result["test"], expected_values)
    np.testing.assert_array_equal(result["mask_test"], expected_mask)

    # Case 3: two-dimensional input with a custom mask prefix
    adapter = bf.Adapter().nan_to_num(keys="test-2d", default_value=0.5, return_mask=True, mask_prefix="new_mask")
    result = adapter.forward(data)
    expected_values = np.array([[1.0, 0.5], [0.5, 4.0]])
    expected_mask = np.array([[1, 0], [0, 1]])
    np.testing.assert_array_equal(result["test-2d"], expected_values)
    np.testing.assert_array_equal(result["new_mask_test-2d"], expected_mask)
320+
321+
301322
def test_nnpe(random_data):
302323
# NNPE cannot be integrated into the adapter fixture and its tests since it modifies the input data
303324
# and therefore breaks existing allclose checks

0 commit comments

Comments
 (0)