Skip to content

Commit 21d8e65

Browse files
committed
Fix imputation
1 parent 70400e7 commit 21d8e65

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

analysis/compare_filters.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
from scipy.io import arff as arff_io
8-
from sklearn import preprocessing, metrics
8+
from sklearn import preprocessing, metrics, compose
99

1010

1111
from NoiseFiltersPy._filters import _implemented_filters
@@ -33,10 +33,15 @@ def calculate_filter_f1(dataset, filter, injector, rate = 0.1):
3333
if target.dtype == object:
3434
le.fit(target)
3535
target = le.transform(target)
36-
attrs = data.drop("class", axis = 1).values
37-
if not np.issubdtype(attrs.dtype, np.number):
38-
enc.fit(attrs)
39-
attrs = enc.transform(attrs).toarray()
36+
attrs = data.drop("class", axis = 1)
37+
if np.any(attrs.dtypes == object):
38+
ct = compose.ColumnTransformer(
39+
transformers = [("encoder", enc, attrs.dtypes == object)],
40+
remainder = "passthrough"
41+
)
42+
attrs = ct.fit_transform(attrs)
43+
44+
attrs = attrs.values
4045

4146
injector = injector(attrs, target, rate)
4247
injector.generate()

examples/aenn_example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66
classes = dataset.target
77
aenn = AENN()
88
filter = aenn(data, classes)
9-
print(filter.cleanData)
9+
print(filter.clean_data)

0 commit comments

Comments
 (0)