Skip to content

Commit 0c58991

Browse files
authored
[MRG] Distance calculation bug solve (#306)
* solve bug * Weights & docs * tests for dist * test dist * pep8
1 parent f162879 commit 0c58991

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

ot/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def euclidean_distances(X, Y, squared=False):
182182
return c
183183

184184

185-
def dist(x1, x2=None, metric='sqeuclidean', p=2):
185+
def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None):
186186
r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
187187
188188
.. note:: This function is backend-compatible and will work on arrays
@@ -202,6 +202,10 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
202202
'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
203203
'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
204204
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
205+
p : float, optional
206+
p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
207+
w : array-like, rank 1, optional
208+
Weights for the weighted metrics. Default value is None; the 'wminkowski' metric raises a ValueError when no weights are given.
205209
206210
207211
Returns
@@ -221,7 +225,9 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
221225
if not get_backend(x1, x2).__name__ == 'numpy':
222226
raise NotImplementedError()
223227
else:
224-
return cdist(x1, x2, metric=metric, p=p)
228+
if metric.endswith("minkowski"):
229+
return cdist(x1, x2, metric=metric, p=p, w=w)
230+
return cdist(x1, x2, metric=metric, w=w)
225231

226232

227233
def dist0(n, method='lin_square'):

test/test_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,26 @@ def test_dist():
117117
np.testing.assert_allclose(D, D2, atol=1e-14)
118118
np.testing.assert_allclose(D, D3, atol=1e-14)
119119

120+
# tests that every metric runs correctly
121+
metrics_w = [
122+
'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
123+
'euclidean', 'hamming', 'jaccard', 'kulsinski',
124+
'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
125+
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'
126+
] # those that support weights
127+
metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version
128+
129+
for metric in metrics_w:
130+
print(metric)
131+
ot.dist(x, x, metric=metric, p=3, w=np.random.random((2, )))
132+
for metric in metrics:
133+
print(metric)
134+
ot.dist(x, x, metric=metric, p=3)
135+
136+
# weighted minkowski but with no weights
137+
with pytest.raises(ValueError):
138+
ot.dist(x, x, metric="wminkowski")
139+
120140

121141
def test_dist_backends(nx):
122142

0 commit comments

Comments
 (0)