Skip to content

Commit c732507

Browse files
author
Maximiliano Marufo da Silva
committed
Updates test_warm_start with a better design
1 parent 4a4c657 commit c732507

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

test/metric_learn_test.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.datasets import (load_iris, make_classification, make_regression,
1010
make_spd_matrix)
1111
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
12-
assert_allclose)
12+
assert_allclose, assert_raises)
1313
from sklearn.exceptions import ConvergenceWarning
1414
from sklearn.utils.validation import check_X_y
1515
from sklearn.preprocessing import StandardScaler
@@ -327,19 +327,34 @@ def test_large_output_iter(self):
327327
@pytest.mark.parametrize("basis", ("lda", "triplet_diffs"))
328328
def test_warm_start(self, basis):
329329
X, y = load_iris(return_X_y=True)
330-
# Should work with warm_start=True even with first fit
330+
# Test that warm_start=True leads to different weights in each fit call
331331
scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
332332
random_state=42, warm_start=True)
333333
scml.fit(X, y)
334-
# Re-fitting should continue from previous fit
335-
before = class_separation(scml.transform(X), y)
334+
w_1 = scml.w_
335+
avg_grad_w_1 = scml.avg_grad_w_
336+
ada_grad_w_1 = scml.ada_grad_w_
336337
scml.fit(X, y)
337-
# We used the whole same dataset, so it can led to overfitting
338-
after = class_separation(scml.transform(X), y)
339-
if basis == "lda":
340-
assert before > after # For lda, class separation improved with re-fit
341-
else:
342-
assert before < after # For triplet_diffs, it got worse
338+
w_2 = scml.w_
339+
assert_raises(AssertionError, assert_array_almost_equal, w_1, w_2)
340+
# And that default warm_start value is False and leads to same
341+
# weights in each fit call
342+
scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
343+
random_state=42)
344+
scml.fit(X, y)
345+
w_3 = scml.w_
346+
scml.fit(X, y)
347+
w_4 = scml.w_
348+
assert_array_almost_equal(w_3, w_4)
349+
# But would lead to same results with warm_strat=True if same init params
350+
# were used
351+
scml.warm_start = True
352+
scml.w_ = w_1
353+
scml.avg_grad_w_ = avg_grad_w_1
354+
scml.ada_grad_w_ = ada_grad_w_1
355+
scml.fit(X, y)
356+
w_5 = scml.w_
357+
assert_array_almost_equal(w_2, w_5)
343358

344359
class TestLSML(MetricTestCase):
345360
def test_iris(self):

0 commit comments

Comments
 (0)