Skip to content

Commit fff1f31

Browse files
committed
[feat] add AENN use case
1 parent 7718c04 commit fff1f31

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

NoiseFiltersPy/AENN.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
@ref: [An Experiment with the Edited Nearest-Neighbor Rule](https://ieeexplore.ieee.org/ielx5/21/4309513/04309523.pdf)
3+
"""
4+
15
from sklearn.model_selection import KFold
26
from sklearn.neighbors import KNeighborsClassifier
37
import numpy as np
@@ -9,18 +13,30 @@
913
class AENN:
    """All-k Edited Nearest Neighbours noise filter.

    For every ``k`` from 1 to ``max_neighbours``, each sample that has not
    been flagged yet is classified by a k-NN model fitted on all the other
    samples. A sample whose predicted label differs from its given label is
    flagged as noise and is not re-examined for larger ``k``.
    """

    def __init__(self, max_neighbours: int = 5, n_jobs: int = -1):
        """Store the configuration and create the result container.

        Args:
            max_neighbours: Largest ``k`` tried; every ``k`` in
                ``1..max_neighbours`` is evaluated.
            n_jobs: Forwarded to ``KNeighborsClassifier`` as ``n_jobs``.
        """
        self.max_neighbours = max_neighbours
        self.filter = Filter(parameters={"max_neighbours": self.max_neighbours})
        self.n_jobs = n_jobs

    def __call__(self, data: t.Sequence, classes: t.Sequence) -> Filter:
        """Flag noisy samples and return the populated ``Filter``.

        Args:
            data: Feature matrix, one row per sample.
            classes: Label of each row of ``data``.

        Returns:
            ``self.filter`` with ``rem_indx`` set to the indices flagged as
            noise (``np.argwhere`` layout, one index per row) and the
            remaining samples passed to ``set_cleanData``.
        """
        # Plain lists/tuples do not support the index-array and boolean-mask
        # indexing used below, so normalise both inputs to arrays once.
        data = np.asarray(data)
        classes = np.asarray(classes)

        # An explicit bool dtype keeps the mask boolean even for empty input
        # (np.array([]) would be float64 and np.invert would reject it).
        self.isNoise = np.zeros(len(classes), dtype=bool)

        for n_neigh in range(1, self.max_neighbours + 1):
            self.clf = KNeighborsClassifier(
                n_neighbors=n_neigh, algorithm='kd_tree', n_jobs=self.n_jobs
            )
            # The candidate list is taken before the inner loop starts, so
            # flags raised during this pass only affect later values of k.
            # Each ``indx`` is a length-1 index array, which keeps a leading
            # sample axis on ``data[indx]`` when it is passed to ``predict``.
            for indx in np.argwhere(np.invert(self.isNoise)):
                # Leave-one-out: fit on every sample except the current one.
                self.clf.fit(
                    np.delete(data, indx, axis=0),
                    np.delete(classes, indx, axis=0),
                )
                pred = self.clf.predict(data[indx])
                self.isNoise[indx] = pred != classes[indx]
            # Progress report after each k.
            print(f"n_neigh: {n_neigh}, is_noise count:, {sum(self.isNoise)}, total: {len(self.isNoise)}")

        self.filter.rem_indx = np.argwhere(self.isNoise)
        notNoise = np.invert(self.isNoise)
        self.filter.set_cleanData(data[notNoise], classes[notNoise])
        return self.filter
32+
33+
34+
if __name__ == "__main__":
    # Usage example: run AENN on the iris data set with one feature removed
    # and print which samples were flagged as noise.
    from sklearn.datasets import load_iris

    iris = load_iris()
    rm_feature_id = iris.feature_names.index("petal length (cm)")
    # Drop the selected feature column; all other columns keep their order.
    features = np.delete(iris.data, rm_feature_id, axis=1)
    labels = iris.target
    # Named ``noise_filter`` so the builtin ``filter`` is not shadowed.
    noise_filter = AENN()(features, labels)
    print(noise_filter.rem_indx.shape, noise_filter.rem_indx[:, 0])

0 commit comments

Comments
 (0)