Skip to content

Commit 7329927

Browse files
author
Maximiliano Marufo da Silva
committed
Updates test_warm_start with a better design
1 parent 8248c84 commit 7329927

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

test/metric_learn_test.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.datasets import (load_iris, make_classification, make_regression,
99
make_spd_matrix)
1010
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
11-
assert_allclose)
11+
assert_allclose, assert_raises)
1212
from sklearn.exceptions import ConvergenceWarning
1313
from sklearn.utils.validation import check_X_y
1414
from sklearn.preprocessing import StandardScaler
@@ -326,19 +326,34 @@ def test_large_output_iter(self):
326326
@pytest.mark.parametrize("basis", ("lda", "triplet_diffs"))
327327
def test_warm_start(self, basis):
328328
X, y = load_iris(return_X_y=True)
329-
# Should work with warm_start=True even with first fit
329+
# Test that warm_start=True leads to different weights in each fit call
330330
scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
331331
random_state=42, warm_start=True)
332332
scml.fit(X, y)
333-
# Re-fitting should continue from previous fit
334-
before = class_separation(scml.transform(X), y)
333+
w_1 = scml.w_
334+
avg_grad_w_1 = scml.avg_grad_w_
335+
ada_grad_w_1 = scml.ada_grad_w_
335336
scml.fit(X, y)
336-
# We used the whole same dataset, so it can led to overfitting
337-
after = class_separation(scml.transform(X), y)
338-
if basis == "lda":
339-
assert before > after # For lda, class separation improved with re-fit
340-
else:
341-
assert before < after # For triplet_diffs, it got worse
337+
w_2 = scml.w_
338+
assert_raises(AssertionError, assert_array_almost_equal, w_1, w_2)
339+
# And that default warm_start value is False and leads to same
340+
# weights in each fit call
341+
scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
342+
random_state=42)
343+
scml.fit(X, y)
344+
w_3 = scml.w_
345+
scml.fit(X, y)
346+
w_4 = scml.w_
347+
assert_array_almost_equal(w_3, w_4)
348+
# But would lead to same results with warm_strat=True if same init params
349+
# were used
350+
scml.warm_start = True
351+
scml.w_ = w_1
352+
scml.avg_grad_w_ = avg_grad_w_1
353+
scml.ada_grad_w_ = ada_grad_w_1
354+
scml.fit(X, y)
355+
w_5 = scml.w_
356+
assert_array_almost_equal(w_2, w_5)
342357

343358
class TestLSML(MetricTestCase):
344359
def test_iris(self):

0 commit comments

Comments
 (0)