Skip to content

Commit 27498f9

Browse files
committed
add tests to select features
1 parent 15eab96 commit 27498f9

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

tests/test_functionality.py

+32-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
StandardNormalVariate,
1919
)
2020
from chemotools.smooth import MeanFilter, MedianFilter, WhittakerSmooth
21-
from chemotools.variable_selection import RangeCut
21+
from chemotools.variable_selection import RangeCut, SelectFeatures
2222
from tests.fixtures import (
2323
spectrum,
2424
spectrum_arpls,
@@ -439,8 +439,8 @@ def test_point_scaler(spectrum):
439439

440440
def test_point_scaler_with_wavenumbers():
441441
# Arrange
442-
wavenumbers = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
443-
spectrum = np.array([[10., 12., 14., 16., 14., 12., 10., 12., 14., 16.]])
442+
wavenumbers = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
443+
spectrum = np.array([[10.0, 12.0, 14.0, 16.0, 14.0, 12.0, 10.0, 12.0, 14.0, 16.0]])
444444

445445
# Act
446446
index_scaler = PointScaler(point=4, wavenumbers=wavenumbers)
@@ -450,7 +450,6 @@ def test_point_scaler_with_wavenumbers():
450450
assert np.allclose(spectrum_corrected[0], spectrum[0] / spectrum[0][3], atol=1e-8)
451451

452452

453-
454453
def test_range_cut_by_index(spectrum):
455454
# Arrange
456455
range_cut = RangeCut(start=0, end=10)
@@ -544,6 +543,35 @@ def test_saviszky_golay_filter_3():
544543
assert np.allclose(spectrum_corrected[0], np.ones((1, 10)), atol=1e-2)
545544

546545

546+
def test_select_features():
547+
# Arrange
548+
spectrum = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
549+
expected = np.array([[1, 2, 3, 8, 9, 10]])
550+
551+
# Act
552+
select_features = SelectFeatures(features=np.array([0, 1, 2, 7, 8, 9]))
553+
spectrum_corrected = select_features.fit_transform(spectrum)
554+
555+
# Assert
556+
assert np.allclose(spectrum_corrected[0], expected, atol=1e-8)
557+
558+
559+
def test_select_features_with_wavenumbers():
560+
# Arrange
561+
wavenumbers = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
562+
spectrum = np.array([[1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0, 55.0, 89.0]])
563+
expected = np.array([[1.0, 2.0, 3.0, 34.0, 55.0, 89.0]])
564+
565+
# Act
566+
select_features = SelectFeatures(
567+
features=np.array([1, 2, 3, 8, 9, 10]), wavenumbers=wavenumbers
568+
)
569+
spectrum_corrected = select_features.fit_transform(spectrum)
570+
571+
# Assert
572+
assert np.allclose(spectrum_corrected[0], expected, atol=1e-8)
573+
574+
547575
def test_standard_normal_variate(spectrum, reference_snv):
548576
# Arrange
549577
snv = StandardNormalVariate()

tests/test_sklearn_compliance.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
SavitzkyGolayFilter,
2525
WhittakerSmooth,
2626
)
27-
from chemotools.variable_selection import RangeCut
27+
from chemotools.variable_selection import RangeCut, SelectFeatures
2828

2929
from tests.fixtures import spectrum
3030

@@ -173,6 +173,14 @@ def test_compliance_savitzky_golay_filter():
173173
check_estimator(transformer)
174174

175175

176+
# SelectFeatures
177+
def test_compliance_select_features():
178+
# Arrange
179+
transformer = SelectFeatures()
180+
# Act & Assert
181+
check_estimator(transformer)
182+
183+
176184
# StandardNormalVariate
177185
def test_compliance_standard_normal_variate():
178186
# Arrange

0 commit comments

Comments
 (0)