Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Faster and/or backend compatible ot.dist #701

Merged
merged 9 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

## Creators and Maintainers

This toolbox has been created by
This toolbox has been created by [Rémi Flamary](https://remi.flamary.com/)
and [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/).

* [Rémi Flamary](https://remi.flamary.com/)
* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)

It is currently maintained by
It is currently maintained by :

* [Rémi Flamary](https://remi.flamary.com/)
* [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)
Expand Down
9 changes: 2 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,9 @@ The examples folder contain several examples and use case for the library. The f

## Acknowledgements

This toolbox has been created by
This toolbox has been created by [Rémi Flamary](https://remi.flamary.com/) and [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/).

* [Rémi Flamary](https://remi.flamary.com/)
* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)

It is currently maintained by
It is currently maintained by :

* [Rémi Flamary](https://remi.flamary.com/)
* [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)
Expand All @@ -218,8 +215,6 @@ POT has benefited from the financing or manpower from the following partners:

<img src="https://pythonot.github.io/master/_static/images/logo_anr.jpg" alt="ANR" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_cnrs.jpg" alt="CNRS" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_3ia.jpg" alt="3IA" style="height:60px;"/><img src="https://pythonot.github.io/master/_static/images/logo_hiparis.png" alt="Hi!PARIS" style="height:60px;"/>



## Contributions and code of conduct

Every contribution is welcome and should respect the [contribution guidelines](https://pythonot.github.io/master/contributing.html). Each member of the project is expected to follow the [code of conduct](https://pythonot.github.io/master/code_of_conduct.html).
Expand Down
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- Added `ot.gaussian.bures_barycenter_gradient_descent` (PR #680)
- Added `ot.gaussian.bures_wasserstein_distance` (PR #680)
- `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680)
- Backend implementation of `ot.dist` for (PR #701)

#### Closed issues
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
Expand Down
105 changes: 91 additions & 14 deletions ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,25 @@
from inspect import signature
from .backend import get_backend, Backend, NumpyBackend, JaxBackend

__time_tic_toc = time.time()
__time_tic_toc = time.perf_counter()


def tic():
r"""Python implementation of Matlab tic() function"""
global __time_tic_toc
__time_tic_toc = time.time()
__time_tic_toc = time.perf_counter()


def toc(message="Elapsed time : {} s"):
r"""Python implementation of Matlab toc() function"""
t = time.time()
t = time.perf_counter()
print(message.format(t - __time_tic_toc))
return t - __time_tic_toc


def toq():
r"""Python implementation of Julia toc() function"""
t = time.time()
t = time.perf_counter()
return t - __time_tic_toc


Expand Down Expand Up @@ -251,7 +251,7 @@
return a2, b2, M2


def euclidean_distances(X, Y, squared=False):
def euclidean_distances(X, Y, squared=False, nx=None):
r"""
Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the
distance matrix between each pair of vectors.
Expand All @@ -270,13 +270,13 @@
-------
distances : array-like, shape (`n_samples_1`, `n_samples_2`)
"""

nx = get_backend(X, Y)
if nx is None:
nx = get_backend(X, Y)

a2 = nx.einsum("ij,ij->i", X, X)
b2 = nx.einsum("ij,ij->i", Y, Y)

c = -2 * nx.dot(X, Y.T)
c = -2 * nx.dot(X, nx.transpose(Y))
c += a2[:, None]
c += b2[None, :]

Expand All @@ -291,11 +291,21 @@
return c


def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None):
def dist(
x1,
x2=None,
metric="sqeuclidean",
p=2,
w=None,
backend="auto",
nx=None,
use_tensor=False,
):
r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`

.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
from all compatible backends for the following metrics:
'sqeuclidean', 'euclidean', 'cityblock', 'minkowski', 'cosine', 'correlation'.

Parameters
----------
Expand All @@ -315,7 +325,17 @@
p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.
w : array-like, rank 1
Weights for the weighted metrics.

backend : str, optional
Backend to use for the computation. If 'auto', the backend is
automatically selected based on the input data. if 'scipy',
the ``scipy.spatial.distance.cdist`` function is used (and gradients are
detached).
use_tensor : bool, optional
If true use tensorized computation for the distance matrix which can
cause memory issues for large datasets. Default is False and the
parameter is used only for the 'cityblock' and 'minkowski' metrics.
nx : Backend, optional
Backend to perform computations on. If omitted, the backend defaults to that of `x1`.

Returns
-------
Expand All @@ -324,12 +344,69 @@
distance matrix computed with given metric

"""
if nx is None:
nx = get_backend(x1, x2)
if x2 is None:
x2 = x1
if metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True)
if backend == "scipy": # force scipy backend with cdist function
x1 = nx.to_numpy(x1)
x2 = nx.to_numpy(x2)
if isinstance(metric, str) and metric.endswith("minkowski"):
return nx.from_numpy(cdist(x1, x2, metric=metric, p=p, w=w))
if w is not None:
return nx.from_numpy(cdist(x1, x2, metric=metric, w=w))
return nx.from_numpy(cdist(x1, x2, metric=metric))

Check warning on line 358 in ot/utils.py

View check run for this annotation

Codecov / codecov/patch

ot/utils.py#L352-L358

Added lines #L352 - L358 were not covered by tests
elif metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True, nx=nx)
elif metric == "euclidean":
return euclidean_distances(x1, x2, squared=False)
return euclidean_distances(x1, x2, squared=False, nx=nx)
elif metric == "cityblock":
if use_tensor:
return nx.sum(nx.abs(x1[:, None, :] - x2[None, :, :]), axis=2)
else:
M = 0.0
for i in range(x1.shape[1]):
M += nx.abs(x1[:, i][:, None] - x2[:, i][None, :])
return M
elif metric == "minkowski":
if w is None:
if use_tensor:
return nx.power(
nx.sum(
nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p), axis=2
),
1 / p,
)
else:
M = 0.0
for i in range(x1.shape[1]):
M += nx.abs(x1[:, i][:, None] - x2[:, i][None, :]) ** p
return M ** (1 / p)
else:
if use_tensor:
return nx.power(

Check warning on line 387 in ot/utils.py

View check run for this annotation

Codecov / codecov/patch

ot/utils.py#L387

Added line #L387 was not covered by tests
nx.sum(
w[None, None, :]
* nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p),
axis=2,
),
1 / p,
)
else:
M = 0.0
for i in range(x1.shape[1]):
M += w[i] * nx.abs(x1[:, i][:, None] - x2[:, i][None, :]) ** p
return M ** (1 / p)
elif metric == "cosine":
nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1))
nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2))
return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :])
elif metric == "correlation":
x1 = x1 - nx.mean(x1, axis=1)[:, None]
x2 = x2 - nx.mean(x2, axis=1)[:, None]
nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1))
nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2))
return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :])
else:
if not get_backend(x1, x2).__name__ == "numpy":
raise NotImplementedError()
Expand Down
65 changes: 57 additions & 8 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,31 @@
import numpy as np
import sys
import pytest
import scipy

lst_metrics = [
"euclidean",
"sqeuclidean",
"cityblock",
"cosine",
"minkowski",
"correlation",
]

lst_all_metrics = lst_metrics + [
"braycurtis",
"canberra",
"chebyshev",
"dice",
"hamming",
"jaccard",
"matching",
"rogerstanimoto",
"russellrao",
"sokalmichener",
"sokalsneath",
"yule",
]


def get_LazyTensor(nx):
Expand Down Expand Up @@ -185,7 +210,7 @@ def test_dist():

assert D4[0, 1] == D4[1, 0]

# dist shoul return squared euclidean
# dist should return squared euclidean
np.testing.assert_allclose(D, D2, atol=1e-14)
np.testing.assert_allclose(D, D3, atol=1e-14)

Expand Down Expand Up @@ -229,21 +254,45 @@ def test_dist():
with pytest.raises(ValueError):
ot.dist(x, x, metric="wminkowski")

with pytest.raises(ValueError):
ot.dist(x, x, metric="fakeone")


def test_dist_backends(nx):
@pytest.mark.parametrize("metric", lst_metrics)
def test_dist_backends(nx, metric):
n = 100
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
x1 = nx.from_numpy(x)

lst_metric = ["euclidean", "sqeuclidean"]
# force numpy backend
D0 = ot.dist(x, x, metric=metric, backend="numpy")

# default backend
D = ot.dist(x, x, metric=metric)

# force nx arrays
D1 = ot.dist(x1, x1, metric=metric)

# low atol because jax forces float32
np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
np.testing.assert_allclose(D, D0, atol=1e-5)


@pytest.mark.parametrize("metric", lst_all_metrics)
def test_dist_vs_cdist(metric):
n = 10

rng = np.random.RandomState(0)
x = rng.randn(n, 2)
y = rng.randn(n + 1, 2)

for metric in lst_metric:
D = ot.dist(x, x, metric=metric)
D1 = ot.dist(x1, x1, metric=metric)
D = ot.dist(x, y, metric=metric)
Dt = ot.dist(x, y, metric=metric, use_tensor=True)
D2 = scipy.spatial.distance.cdist(x, y, metric=metric)

# low atol because jax forces float32
np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
np.testing.assert_allclose(D, D2, atol=1e-15)
np.testing.assert_allclose(D, Dt, atol=1e-15)


def test_dist0():
Expand Down
Loading