Skip to content

Commit 54fc3fb

Browse files
authored
Fixes #447: forward sklearn estimator attributes (#574)
1 parent 1895622 commit 54fc3fb

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

tslearn/svm/svm.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,27 @@
1313

1414

1515
class TimeSeriesSVMMixin:
16+
17+
@property
18+
def support_(self):
19+
check_is_fitted(self, ['svm_estimator_', '_X_fit'])
20+
return getattr(self, "svm_estimator_").support_
21+
22+
@property
23+
def dual_coef_(self):
24+
check_is_fitted(self, ['svm_estimator_', '_X_fit'])
25+
return getattr(self, "svm_estimator_").dual_coef_
26+
27+
@property
28+
def coef_(self):
29+
check_is_fitted(self, ['svm_estimator_', '_X_fit'])
30+
return getattr(self, "svm_estimator_").coef_
31+
32+
@property
33+
def intercept_(self):
34+
check_is_fitted(self, ['svm_estimator_', '_X_fit'])
35+
return getattr(self, "svm_estimator_").intercept_
36+
1637
def _preprocess_sklearn(self, X, y=None, fit_time=False):
1738
force_all_finite = self.kernel not in VARIABLE_LENGTH_METRICS
1839
if y is None:
@@ -446,9 +467,6 @@ class TimeSeriesSVR(TimeSeriesSVMMixin, RegressorMixin,
446467
intercept_ : array, shape = [1]
447468
Constants in decision function.
448469
449-
sample_weight : array-like, shape = [n_samples]
450-
Individual weights for each sample
451-
452470
svm_estimator_ : sklearn.svm.SVR
453471
The underlying sklearn estimator
454472

tslearn/tests/test_svm.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import numpy as np
22

3+
import pytest
4+
5+
from sklearn.exceptions import NotFittedError
6+
37
from tslearn.metrics import cdist_gak
48
from tslearn.svm import TimeSeriesSVC, TimeSeriesSVR
59

610
__author__ = 'Romain Tavenard romain.tavenard[at]univ-rennes2.fr'
711

812

13+
914
def test_gamma_value_svm():
1015
n, sz, d = 5, 10, 3
1116
rng = np.random.RandomState(0)
@@ -22,3 +27,22 @@ def test_gamma_value_svm():
2227
cdist_mat = cdist_gak(time_series, sigma=np.sqrt(gamma / 2.))
2328

2429
np.testing.assert_allclose(sklearn_X, cdist_mat)
30+
31+
def test_attributes():
32+
n, sz, d = 5, 10, 3
33+
rng = np.random.RandomState(0)
34+
time_series = rng.randn(n, sz, d)
35+
labels = rng.randint(low=0, high=2, size=n)
36+
37+
for ModelClass in [TimeSeriesSVC, TimeSeriesSVR]:
38+
linear_model = ModelClass(kernel="linear")
39+
40+
for attr in ['coef_', 'support_', 'support_vectors_',
41+
'dual_coef_', 'coef_', 'intercept_']:
42+
with pytest.raises(NotFittedError):
43+
getattr(linear_model, attr)
44+
45+
linear_model.fit(time_series, labels)
46+
for attr in ['coef_', 'support_', 'support_vectors_',
47+
'dual_coef_', 'coef_', 'intercept_']:
48+
assert hasattr(linear_model, attr)

0 commit comments

Comments
 (0)