|  | 
| 9 | 9 | from sklearn.datasets import (load_iris, make_classification, make_regression, | 
| 10 | 10 |                               make_spd_matrix) | 
| 11 | 11 | from numpy.testing import (assert_array_almost_equal, assert_array_equal, | 
| 12 |  | -                           assert_allclose) | 
|  | 12 | +                           assert_allclose, assert_raises) | 
| 13 | 13 | from sklearn.exceptions import ConvergenceWarning | 
| 14 | 14 | from sklearn.utils.validation import check_X_y | 
| 15 | 15 | from sklearn.preprocessing import StandardScaler | 
| @@ -327,19 +327,34 @@ def test_large_output_iter(self): | 
| 327 | 327 |   @pytest.mark.parametrize("basis", ("lda", "triplet_diffs")) | 
| 328 | 328 |   def test_warm_start(self, basis): | 
| 329 | 329 |     X, y = load_iris(return_X_y=True) | 
| 330 |  | -    # Should work with warm_start=True even with first fit | 
|  | 330 | +    # Test that warm_start=True leads to different weights in each fit call | 
| 331 | 331 |     scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5, | 
| 332 | 332 |                            random_state=42, warm_start=True) | 
| 333 | 333 |     scml.fit(X, y) | 
| 334 |  | -    # Re-fitting should continue from previous fit | 
| 335 |  | -    before = class_separation(scml.transform(X), y) | 
|  | 334 | +    w_1 = scml.w_ | 
|  | 335 | +    avg_grad_w_1 = scml.avg_grad_w_ | 
|  | 336 | +    ada_grad_w_1 = scml.ada_grad_w_ | 
| 336 | 337 |     scml.fit(X, y) | 
| 337 |  | -    # We used the whole same dataset, so it can led to overfitting | 
| 338 |  | -    after = class_separation(scml.transform(X), y) | 
| 339 |  | -    if basis == "lda": | 
| 340 |  | -      assert before > after  # For lda, class separation improved with re-fit | 
| 341 |  | -    else: | 
| 342 |  | -      assert before < after  # For triplet_diffs, it got worse | 
|  | 338 | +    w_2 = scml.w_ | 
|  | 339 | +    assert_raises(AssertionError, assert_array_almost_equal, w_1, w_2) | 
|  | 340 | +    # And that default warm_start value is False and leads to same | 
|  | 341 | +    # weights in each fit call | 
|  | 342 | +    scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5, | 
|  | 343 | +                           random_state=42) | 
|  | 344 | +    scml.fit(X, y) | 
|  | 345 | +    w_3 = scml.w_ | 
|  | 346 | +    scml.fit(X, y) | 
|  | 347 | +    w_4 = scml.w_ | 
|  | 348 | +    assert_array_almost_equal(w_3, w_4) | 
|  | 349 | +    # But would lead to same results with warm_strat=True if same init params | 
|  | 350 | +    # were used | 
|  | 351 | +    scml.warm_start = True | 
|  | 352 | +    scml.w_ = w_1 | 
|  | 353 | +    scml.avg_grad_w_ = avg_grad_w_1 | 
|  | 354 | +    scml.ada_grad_w_ = ada_grad_w_1 | 
|  | 355 | +    scml.fit(X, y) | 
|  | 356 | +    w_5 = scml.w_ | 
|  | 357 | +    assert_array_almost_equal(w_2, w_5) | 
| 343 | 358 | 
 | 
| 344 | 359 | class TestLSML(MetricTestCase): | 
| 345 | 360 |   def test_iris(self): | 
|  | 
0 commit comments