Skip to content

Commit 526db78

Browse files
authored
add FilterOutlier (#229)
1 parent cfc5fdb commit 526db78

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

Diff for: ipsuite/configuration_selection/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Configuration Selection Nodes."""
22

33
from ipsuite.configuration_selection.base import ConfigurationSelection
4+
from ipsuite.configuration_selection.filter import FilterOutlier
45
from ipsuite.configuration_selection.index import IndexSelection
56
from ipsuite.configuration_selection.kernel import KernelSelection
67
from ipsuite.configuration_selection.random import RandomSelection
@@ -20,4 +21,5 @@
2021
"IndexSelection",
2122
"ThresholdSelection",
2223
"SplitSelection",
24+
"FilterOutlier",
2325
]

Diff for: ipsuite/configuration_selection/filter.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import typing as t
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
import zntrack
6+
7+
from ipsuite import base
8+
9+
10+
class FilterOutlier(base.ProcessAtoms):
11+
"""Remove outliers from the data based on a given property.
12+
13+
Attributes
14+
----------
15+
key : str, default="energy"
16+
The property to filter on.
17+
threshold : float, default=3
18+
The threshold for filtering in units of standard deviations.
19+
direction : {"above", "below", "both"}, default="both"
20+
The direction to filter in.
21+
"""
22+
23+
key: str = zntrack.params("energy")
24+
threshold: float = zntrack.params(3)
25+
direction: t.Literal["above", "below", "both"] = zntrack.params("both")
26+
27+
filtered_indices: list = zntrack.outs()
28+
histogram: str = zntrack.outs_path(zntrack.nwd / "histogram.png")
29+
30+
def run(self):
31+
values = [x.calc.results[self.key] for x in self.data]
32+
mean = np.mean(values)
33+
std = np.std(values)
34+
35+
if self.direction == "above":
36+
self.filtered_indices = [
37+
i for i, x in enumerate(values) if x > mean + self.threshold * std
38+
]
39+
elif self.direction == "below":
40+
self.filtered_indices = [
41+
i for i, x in enumerate(values) if x < mean - self.threshold * std
42+
]
43+
else:
44+
self.filtered_indices = [
45+
i
46+
for i, x in enumerate(values)
47+
if x > mean + self.threshold * std or x < mean - self.threshold * std
48+
]
49+
50+
fig, ax = plt.subplots(3, figsize=(10, 10))
51+
ax[0].hist(values, bins=100)
52+
ax[0].set_title("All")
53+
ax[1].hist(
54+
[values[i] for i in range(len(values)) if i not in self.filtered_indices],
55+
bins=100,
56+
)
57+
ax[1].set_title("Filtered")
58+
ax[2].hist([values[i] for i in self.filtered_indices], bins=100)
59+
ax[2].set_title("Excluded")
60+
fig.savefig(self.histogram, bbox_inches="tight")
61+
62+
@property
63+
def atoms(self):
64+
return [
65+
self.data[i] for i in range(len(self.data)) if i not in self.filtered_indices
66+
]
67+
68+
@property
69+
def excluded_atoms(self):
70+
return [self.data[i] for i in self.filtered_indices]

Diff for: ipsuite/nodes.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class _Nodes:
2626
)
2727
UniformTemporalSelection = "ipsuite.configuration_selection.UniformTemporalSelection"
2828
ThresholdSelection = "ipsuite.configuration_selection.ThresholdSelection"
29+
FilterOutlier = "ipsuite.configuration_selection.FilterOutlier"
2930
BatchKernelSelection = "ipsuite.models.apax.BatchKernelSelection"
3031

3132
# Configuration Comparison

Diff for: tests/integration/configuration_selection/test_index.py

+13
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,16 @@ def test_exclude_configurations_list(proj_path, traj_file):
137137
assert train_data[0].selected_configurations == {"AddData": [5, 6, 7, 8, 9]}
138138
assert test_data[0].selected_configurations == {"AddData": [0, 1, 2, 3, 4]}
139139
assert validation_data.selected_configurations == {"AddData": [10, 11, 12, 13, 14]}
140+
141+
142+
def test_filter_outlier(proj_path, traj_file):
143+
with ips.Project() as project:
144+
data = ips.AddData(file=traj_file)
145+
filtered_data = ips.configuration_selection.FilterOutlier(
146+
data=data.atoms, key="energy", threshold=1, direction="both"
147+
)
148+
149+
project.run()
150+
151+
filtered_data.load()
152+
assert len(filtered_data.atoms) == 13

0 commit comments

Comments
 (0)