Skip to content

Commit 928a67a

Browse files
authored
[MRG] Faster and/or backend compatible ot.dist (#701)
* better dist and tests * small stuff * change api * better tests and doc
1 parent cbdf979 commit 928a67a

File tree

5 files changed

+154
-34
lines changed

5 files changed

+154
-34
lines changed

CONTRIBUTORS.md

+3-5
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22

33
## Creators and Maintainers
44

5-
This toolbox has been created by
5+
This toolbox has been created by [Rémi Flamary](https://remi.flamary.com/)
6+
and [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/).
67

7-
* [Rémi Flamary](https://remi.flamary.com/)
8-
* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)
9-
10-
It is currently maintained by
8+
It is currently maintained by :
119

1210
* [Rémi Flamary](https://remi.flamary.com/)
1311
* [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)

README.md

+2-7
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,9 @@ The examples folder contain several examples and use case for the library. The f
202202

203203
## Acknowledgements
204204

205-
This toolbox has been created by
205+
This toolbox has been created by [Rémi Flamary](https://remi.flamary.com/) and [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/).
206206

207-
* [Rémi Flamary](https://remi.flamary.com/)
208-
* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)
209-
210-
It is currently maintained by
207+
It is currently maintained by :
211208

212209
* [Rémi Flamary](https://remi.flamary.com/)
213210
* [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)
@@ -218,8 +215,6 @@ POT has benefited from the financing or manpower from the following partners:
218215

219216
<img src="https://pythonot.github.io/master/_static/images/logo_anr.jpg" alt="ANR" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_cnrs.jpg" alt="CNRS" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_3ia.jpg" alt="3IA" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_hiparis.png" alt="Hi!PARIS" style="height:60px;"/>
220217

221-
222-
223218
## Contributions and code of conduct
224219

225220
Every contribution is welcome and should respect the [contribution guidelines](https://pythonot.github.io/master/contributing.html). Each member of the project is expected to follow the [code of conduct](https://pythonot.github.io/master/code_of_conduct.html).

RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
- Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680)
1515
- Added `ot.gaussian.bures_wasserstein_distance` (PR #680)
1616
- `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680)
17+
- Backend implementation of `ot.dist` for (PR #701)
1718

1819
#### Closed issues
1920
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

ot/utils.py

+91-14
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,25 @@
1717
from inspect import signature
1818
from .backend import get_backend, Backend, NumpyBackend, JaxBackend
1919

20-
__time_tic_toc = time.time()
20+
__time_tic_toc = time.perf_counter()
2121

2222

2323
def tic():
2424
r"""Python implementation of Matlab tic() function"""
2525
global __time_tic_toc
26-
__time_tic_toc = time.time()
26+
__time_tic_toc = time.perf_counter()
2727

2828

2929
def toc(message="Elapsed time : {} s"):
3030
r"""Python implementation of Matlab toc() function"""
31-
t = time.time()
31+
t = time.perf_counter()
3232
print(message.format(t - __time_tic_toc))
3333
return t - __time_tic_toc
3434

3535

3636
def toq():
3737
r"""Python implementation of Julia toc() function"""
38-
t = time.time()
38+
t = time.perf_counter()
3939
return t - __time_tic_toc
4040

4141

@@ -251,7 +251,7 @@ def clean_zeros(a, b, M):
251251
return a2, b2, M2
252252

253253

254-
def euclidean_distances(X, Y, squared=False):
254+
def euclidean_distances(X, Y, squared=False, nx=None):
255255
r"""
256256
Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the
257257
distance matrix between each pair of vectors.
@@ -270,13 +270,13 @@ def euclidean_distances(X, Y, squared=False):
270270
-------
271271
distances : array-like, shape (`n_samples_1`, `n_samples_2`)
272272
"""
273-
274-
nx = get_backend(X, Y)
273+
if nx is None:
274+
nx = get_backend(X, Y)
275275

276276
a2 = nx.einsum("ij,ij->i", X, X)
277277
b2 = nx.einsum("ij,ij->i", Y, Y)
278278

279-
c = -2 * nx.dot(X, Y.T)
279+
c = -2 * nx.dot(X, nx.transpose(Y))
280280
c += a2[:, None]
281281
c += b2[None, :]
282282

@@ -291,11 +291,21 @@ def euclidean_distances(X, Y, squared=False):
291291
return c
292292

293293

294-
def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
294+
def dist(
295+
x1,
296+
x2=None,
297+
metric="sqeuclidean",
298+
p=2,
299+
w=None,
300+
backend="auto",
301+
nx=None,
302+
use_tensor=False,
303+
):
295304
r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
296305
297306
.. note:: This function is backend-compatible and will work on arrays
298-
from all compatible backends.
307+
from all compatible backends for the following metrics:
308+
'sqeuclidean', 'euclidean', 'cityblock', 'minkowski', 'cosine', 'correlation'.
299309
300310
Parameters
301311
----------
@@ -315,7 +325,17 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
315325
p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
316326
w : array-like, rank 1
317327
Weights for the weighted metrics.
318-
328+
backend : str, optional
329+
Backend to use for the computation. If 'auto', the backend is
330+
automatically selected based on the input data. if 'scipy',
331+
the ``scipy.spatial.distance.cdist`` function is used (and gradients are
332+
detached).
333+
use_tensor : bool, optional
334+
If true use tensorized computation for the distance matrix which can
335+
cause memory issues for large datasets. Default is False and the
336+
parameter is used only for the 'cityblock' and 'minkowski' metrics.
337+
nx : Backend, optional
338+
Backend to perform computations on. If omitted, the backend defaults to that of `x1`.
319339
320340
Returns
321341
-------
@@ -324,12 +344,69 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
324344
distance matrix computed with given metric
325345
326346
"""
347+
if nx is None:
348+
nx = get_backend(x1, x2)
327349
if x2 is None:
328350
x2 = x1
329-
if metric == "sqeuclidean":
330-
return euclidean_distances(x1, x2, squared=True)
351+
if backend == "scipy": # force scipy backend with cdist function
352+
x1 = nx.to_numpy(x1)
353+
x2 = nx.to_numpy(x2)
354+
if isinstance(metric, str) and metric.endswith("minkowski"):
355+
return nx.from_numpy(cdist(x1, x2, metric=metric, p=p, w=w))
356+
if w is not None:
357+
return nx.from_numpy(cdist(x1, x2, metric=metric, w=w))
358+
return nx.from_numpy(cdist(x1, x2, metric=metric))
359+
elif metric == "sqeuclidean":
360+
return euclidean_distances(x1, x2, squared=True, nx=nx)
331361
elif metric == "euclidean":
332-
return euclidean_distances(x1, x2, squared=False)
362+
return euclidean_distances(x1, x2, squared=False, nx=nx)
363+
elif metric == "cityblock":
364+
if use_tensor:
365+
return nx.sum(nx.abs(x1[:, None, :] - x2[None, :, :]), axis=2)
366+
else:
367+
M = 0.0
368+
for i in range(x1.shape[1]):
369+
M += nx.abs(x1[:, i][:, None] - x2[:, i][None, :])
370+
return M
371+
elif metric == "minkowski":
372+
if w is None:
373+
if use_tensor:
374+
return nx.power(
375+
nx.sum(
376+
nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p), axis=2
377+
),
378+
1 / p,
379+
)
380+
else:
381+
M = 0.0
382+
for i in range(x1.shape[1]):
383+
M += nx.abs(x1[:, i][:, None] - x2[:, i][None, :]) ** p
384+
return M ** (1 / p)
385+
else:
386+
if use_tensor:
387+
return nx.power(
388+
nx.sum(
389+
w[None, None, :]
390+
* nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p),
391+
axis=2,
392+
),
393+
1 / p,
394+
)
395+
else:
396+
M = 0.0
397+
for i in range(x1.shape[1]):
398+
M += w[i] * nx.abs(x1[:, i][:, None] - x2[:, i][None, :]) ** p
399+
return M ** (1 / p)
400+
elif metric == "cosine":
401+
nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1))
402+
nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2))
403+
return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :])
404+
elif metric == "correlation":
405+
x1 = x1 - nx.mean(x1, axis=1)[:, None]
406+
x2 = x2 - nx.mean(x2, axis=1)[:, None]
407+
nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1))
408+
nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2))
409+
return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :])
333410
else:
334411
if not get_backend(x1, x2).__name__ == "numpy":
335412
raise NotImplementedError()

test/test_utils.py

+57-8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,31 @@
88
import numpy as np
99
import sys
1010
import pytest
11+
import scipy
12+
13+
lst_metrics = [
14+
"euclidean",
15+
"sqeuclidean",
16+
"cityblock",
17+
"cosine",
18+
"minkowski",
19+
"correlation",
20+
]
21+
22+
lst_all_metrics = lst_metrics + [
23+
"braycurtis",
24+
"canberra",
25+
"chebyshev",
26+
"dice",
27+
"hamming",
28+
"jaccard",
29+
"matching",
30+
"rogerstanimoto",
31+
"russellrao",
32+
"sokalmichener",
33+
"sokalsneath",
34+
"yule",
35+
]
1136

1237

1338
def get_LazyTensor(nx):
@@ -185,7 +210,7 @@ def test_dist():
185210

186211
assert D4[0, 1] == D4[1, 0]
187212

188-
# dist shoul return squared euclidean
213+
# dist should return squared euclidean
189214
np.testing.assert_allclose(D, D2, atol=1e-14)
190215
np.testing.assert_allclose(D, D3, atol=1e-14)
191216

@@ -229,21 +254,45 @@ def test_dist():
229254
with pytest.raises(ValueError):
230255
ot.dist(x, x, metric="wminkowski")
231256

257+
with pytest.raises(ValueError):
258+
ot.dist(x, x, metric="fakeone")
259+
232260

233-
def test_dist_backends(nx):
261+
@pytest.mark.parametrize("metric", lst_metrics)
262+
def test_dist_backends(nx, metric):
234263
n = 100
235264
rng = np.random.RandomState(0)
236265
x = rng.randn(n, 2)
237266
x1 = nx.from_numpy(x)
238267

239-
lst_metric = ["euclidean", "sqeuclidean"]
268+
# force numpy backend
269+
D0 = ot.dist(x, x, metric=metric, backend="numpy")
270+
271+
# default backend
272+
D = ot.dist(x, x, metric=metric)
273+
274+
# force nx arrays
275+
D1 = ot.dist(x1, x1, metric=metric)
276+
277+
# low atol because jax forces float32
278+
np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
279+
np.testing.assert_allclose(D, D0, atol=1e-5)
280+
281+
282+
@pytest.mark.parametrize("metric", lst_all_metrics)
283+
def test_dist_vs_cdist(metric):
284+
n = 10
285+
286+
rng = np.random.RandomState(0)
287+
x = rng.randn(n, 2)
288+
y = rng.randn(n + 1, 2)
240289

241-
for metric in lst_metric:
242-
D = ot.dist(x, x, metric=metric)
243-
D1 = ot.dist(x1, x1, metric=metric)
290+
D = ot.dist(x, y, metric=metric)
291+
Dt = ot.dist(x, y, metric=metric, use_tensor=True)
292+
D2 = scipy.spatial.distance.cdist(x, y, metric=metric)
244293

245-
# low atol because jax forces float32
246-
np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
294+
np.testing.assert_allclose(D, D2, atol=1e-15)
295+
np.testing.assert_allclose(D, Dt, atol=1e-15)
247296

248297

249298
def test_dist0():

0 commit comments

Comments
 (0)