Skip to content

Commit ac0a141

Browse files
committed
ctw fixes
1 parent f1594b2 commit ac0a141

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

tslearn/metrics/ctw.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""
2+
The :mod:`tslearn.metrics:ctc` module provides utilities related to
3+
Canonical Time Warping.
4+
"""
5+
16
import numpy as np
27
from sklearn.cross_decomposition import CCA
38

@@ -157,15 +162,14 @@ def ctw_path(
157162
# (possibly truncated to a fixed number of features) inputs
158163
seq1_tr = s1 @ be.eye(s1.shape[1], n_components, dtype=be.float64)
159164
seq2_tr = s2 @ be.eye(s2.shape[1], n_components, dtype=be.float64)
160-
current_path, score_match = dtw_path(
165+
current_path, current_score = dtw_path(
161166
seq1_tr,
162167
seq2_tr,
163168
global_constraint=global_constraint,
164169
sakoe_chiba_radius=sakoe_chiba_radius,
165170
itakura_max_slope=itakura_max_slope,
166171
be=be,
167172
)
168-
current_score = score_match
169173

170174
if verbose:
171175
print("Iteration 0, score={}".format(current_score))
@@ -176,7 +180,7 @@ def ctw_path(
176180
cca.fit(Wx @ s1, Wy @ s2)
177181
seq1_tr, seq2_tr = cca.transform(s1, s2)
178182

179-
current_path, score_match = dtw_path(
183+
new_path, new_score = dtw_path(
180184
seq1_tr,
181185
seq2_tr,
182186
global_constraint=global_constraint,
@@ -185,10 +189,10 @@ def ctw_path(
185189
be=be,
186190
)
187191

188-
if np.array_equal(current_path, current_path):
192+
if np.array_equal(current_path, new_path):
189193
break
190194

191-
current_score = score_match
195+
current_path, current_score = new_path, new_score
192196

193197
if verbose:
194198
print("Iteration {}, score={}".format(it + 1, current_score))
@@ -329,7 +333,7 @@ def cdist_ctw(
329333
If shape is (n_ts1, sz1), the dataset is composed of univariate time series.
330334
If shape is (sz1,), the dataset is composed of a unique univariate time series.
331335
dataset2 : None or array-like, shape=(n_ts2, sz2, d) or (n_ts2, sz2) or (sz2,) (default: None)
332-
Another dataset of time series.
336+
Another dataset of time series.
333337
If `None`, self-similarity of `dataset1` is returned.
334338
If shape is (n_ts2, sz2), the dataset is composed of univariate time series.
335339
If shape is (sz2,), the dataset is composed of a unique univariate time series.

tslearn/tests/test_metrics.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,22 @@ def test_ctw():
7777
np.testing.assert_allclose(dist, 1.0)
7878
assert backend.belongs_to_backend(dist)
7979

80+
x = [[1, 1], [3, 4], [126, 126]]
81+
y = [[1, 1.], [3., 3], [4., 4], [2., 2], [0, 0], [127, 127]]
82+
dist_0 = tslearn.metrics.ctw_path(
83+
cast(x, array_type),
84+
cast(y, array_type),
85+
max_iter=2,
86+
be=be
87+
)[2]
88+
dist_1 = tslearn.metrics.ctw_path(
89+
cast(x, array_type),
90+
cast(y, array_type),
91+
max_iter=3,
92+
be=be
93+
)[2]
94+
assert dist_0 >= dist_1
95+
8096
# dtw
8197
n1, n2, d1, d2 = 15, 10, 3, 1
8298
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)