|
8 | 8 | from sklearn.datasets import (load_iris, make_classification, make_regression,
|
9 | 9 | make_spd_matrix)
|
10 | 10 | from numpy.testing import (assert_array_almost_equal, assert_array_equal,
|
11 |
| - assert_allclose) |
| 11 | + assert_allclose, assert_raises) |
12 | 12 | from sklearn.exceptions import ConvergenceWarning
|
13 | 13 | from sklearn.utils.validation import check_X_y
|
14 | 14 | from sklearn.preprocessing import StandardScaler
|
@@ -326,19 +326,34 @@ def test_large_output_iter(self):
|
326 | 326 | @pytest.mark.parametrize("basis", ("lda", "triplet_diffs"))
|
327 | 327 | def test_warm_start(self, basis):
|
328 | 328 | X, y = load_iris(return_X_y=True)
|
329 |
| - # Should work with warm_start=True even with first fit |
| 329 | + # Test that warm_start=True leads to different weights in each fit call |
330 | 330 | scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
|
331 | 331 | random_state=42, warm_start=True)
|
332 | 332 | scml.fit(X, y)
|
333 |
| - # Re-fitting should continue from previous fit |
334 |
| - before = class_separation(scml.transform(X), y) |
| 333 | + w_1 = scml.w_ |
| 334 | + avg_grad_w_1 = scml.avg_grad_w_ |
| 335 | + ada_grad_w_1 = scml.ada_grad_w_ |
335 | 336 | scml.fit(X, y)
|
336 |
| - # We used the whole same dataset, so it can led to overfitting |
337 |
| - after = class_separation(scml.transform(X), y) |
338 |
| - if basis == "lda": |
339 |
| - assert before > after # For lda, class separation improved with re-fit |
340 |
| - else: |
341 |
| - assert before < after # For triplet_diffs, it got worse |
| 337 | + w_2 = scml.w_ |
| 338 | + assert_raises(AssertionError, assert_array_almost_equal, w_1, w_2) |
| 339 | + # And that default warm_start value is False and leads to same |
| 340 | + # weights in each fit call |
| 341 | + scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5, |
| 342 | + random_state=42) |
| 343 | + scml.fit(X, y) |
| 344 | + w_3 = scml.w_ |
| 345 | + scml.fit(X, y) |
| 346 | + w_4 = scml.w_ |
| 347 | + assert_array_almost_equal(w_3, w_4) |
| 348 | + # But would lead to same results with warm_strat=True if same init params |
| 349 | + # were used |
| 350 | + scml.warm_start = True |
| 351 | + scml.w_ = w_1 |
| 352 | + scml.avg_grad_w_ = avg_grad_w_1 |
| 353 | + scml.ada_grad_w_ = ada_grad_w_1 |
| 354 | + scml.fit(X, y) |
| 355 | + w_5 = scml.w_ |
| 356 | + assert_array_almost_equal(w_2, w_5) |
342 | 357 |
|
343 | 358 | class TestLSML(MetricTestCase):
|
344 | 359 | def test_iris(self):
|
|
0 commit comments