|
9 | 9 | from sklearn.datasets import (load_iris, make_classification, make_regression,
|
10 | 10 | make_spd_matrix)
|
11 | 11 | from numpy.testing import (assert_array_almost_equal, assert_array_equal,
|
12 |
| - assert_allclose) |
| 12 | + assert_allclose, assert_raises) |
13 | 13 | from sklearn.exceptions import ConvergenceWarning
|
14 | 14 | from sklearn.utils.validation import check_X_y
|
15 | 15 | from sklearn.preprocessing import StandardScaler
|
@@ -327,19 +327,34 @@ def test_large_output_iter(self):
|
327 | 327 | @pytest.mark.parametrize("basis", ("lda", "triplet_diffs"))
|
328 | 328 | def test_warm_start(self, basis):
|
329 | 329 | X, y = load_iris(return_X_y=True)
|
330 |
| - # Should work with warm_start=True even with first fit |
| 330 | + # Test that warm_start=True leads to different weights in each fit call |
331 | 331 | scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
|
332 | 332 | random_state=42, warm_start=True)
|
333 | 333 | scml.fit(X, y)
|
334 |
| - # Re-fitting should continue from previous fit |
335 |
| - before = class_separation(scml.transform(X), y) |
| 334 | + w_1 = scml.w_ |
| 335 | + avg_grad_w_1 = scml.avg_grad_w_ |
| 336 | + ada_grad_w_1 = scml.ada_grad_w_ |
336 | 337 | scml.fit(X, y)
|
337 |
| - # We used the whole same dataset, so it can led to overfitting |
338 |
| - after = class_separation(scml.transform(X), y) |
339 |
| - if basis == "lda": |
340 |
| - assert before > after # For lda, class separation improved with re-fit |
341 |
| - else: |
342 |
| - assert before < after # For triplet_diffs, it got worse |
| 338 | + w_2 = scml.w_ |
| 339 | + assert_raises(AssertionError, assert_array_almost_equal, w_1, w_2) |
| 340 | + # And that default warm_start value is False and leads to same |
| 341 | + # weights in each fit call |
| 342 | + scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5, |
| 343 | + random_state=42) |
| 344 | + scml.fit(X, y) |
| 345 | + w_3 = scml.w_ |
| 346 | + scml.fit(X, y) |
| 347 | + w_4 = scml.w_ |
| 348 | + assert_array_almost_equal(w_3, w_4) |
| 349 | + # But would lead to same results with warm_strat=True if same init params |
| 350 | + # were used |
| 351 | + scml.warm_start = True |
| 352 | + scml.w_ = w_1 |
| 353 | + scml.avg_grad_w_ = avg_grad_w_1 |
| 354 | + scml.ada_grad_w_ = ada_grad_w_1 |
| 355 | + scml.fit(X, y) |
| 356 | + w_5 = scml.w_ |
| 357 | + assert_array_almost_equal(w_2, w_5) |
343 | 358 |
|
344 | 359 | class TestLSML(MetricTestCase):
|
345 | 360 | def test_iris(self):
|
|
0 commit comments