Skip to content

Commit 9043960

Browse files
[MRG] Replaced coo_matrix with coo_array better compatability and added tes… (#782)
* Replaced coo_matrix with coo_array better compatability and added test to test coo_array functionnality * Updated release file * Replaced some more coo_matrix calls --------- Co-authored-by: Rémi Flamary <[email protected]>
1 parent 5de4614 commit 9043960

File tree

4 files changed

+24
-16
lines changed

4 files changed

+24
-16
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ This new release adds support for sparse cost matrices in the exact EMD solver.
66

77
#### New features
88
- Add support for sparse cost matrices in exact EMD solver `ot.emd` and `ot.emd2` (PR #778)
9+
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` API (PR #TBD)
910

1011
#### Closed issues
1112
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)

ot/backend.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
import scipy
9595
import scipy.linalg
9696
import scipy.special as special
97-
from scipy.sparse import coo_matrix, csr_matrix, issparse
97+
from scipy.sparse import coo_array, csr_matrix, issparse
9898

9999
DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH"
100100
DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX"
@@ -802,9 +802,9 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
802802
r"""
803803
Creates a sparse tensor in COOrdinate format.
804804
805-
This function follows the api from :any:`scipy.sparse.coo_matrix`
805+
This function follows the api from :any:`scipy.sparse.coo_array`
806806
807-
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
807+
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_array.html
808808
"""
809809
raise NotImplementedError()
810810

@@ -1354,9 +1354,9 @@ def randperm(self, size, type_as=None):
13541354

13551355
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
13561356
if type_as is None:
1357-
return coo_matrix((data, (rows, cols)), shape=shape)
1357+
return coo_array((data, (rows, cols)), shape=shape)
13581358
else:
1359-
return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
1359+
return coo_array((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
13601360

13611361
def issparse(self, a):
13621362
return issparse(a)
@@ -1384,9 +1384,9 @@ def todense(self, a):
13841384
return a
13851385

13861386
def sparse_coo_data(self, a):
1387-
# Convert to COO format if needed
1388-
if not isinstance(a, coo_matrix):
1389-
a_coo = coo_matrix(a)
1387+
# Convert to COO array format if needed
1388+
if not isinstance(a, coo_array):
1389+
a_coo = coo_array(a)
13901390
else:
13911391
a_coo = a
13921392

@@ -1815,9 +1815,7 @@ def sparse_coo_data(self, a):
18151815
# JAX doesn't support sparse matrices, so this shouldn't be called
18161816
# But if it is, convert the dense array to sparse using scipy
18171817
a_np = self.to_numpy(a)
1818-
from scipy.sparse import coo_matrix
1819-
1820-
a_coo = coo_matrix(a_np)
1818+
a_coo = coo_array(a_np)
18211819
return a_coo.row, a_coo.col, a_coo.data, a_coo.shape
18221820

18231821
def where(self, condition, x=None, y=None):
@@ -2804,10 +2802,10 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
28042802
rows = self.from_numpy(rows)
28052803
cols = self.from_numpy(cols)
28062804
if type_as is None:
2807-
return cupyx.scipy.sparse.coo_matrix((data, (rows, cols)), shape=shape)
2805+
return cupyx.scipy.sparse.coo_array((data, (rows, cols)), shape=shape)
28082806
else:
28092807
with cp.cuda.Device(type_as.device):
2810-
return cupyx.scipy.sparse.coo_matrix(
2808+
return cupyx.scipy.sparse.coo_array(
28112809
(data, (rows, cols)), shape=shape, dtype=type_as.dtype
28122810
)
28132811

ot/plot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import numpy as np
1616
import matplotlib.pylab as pl
1717
from matplotlib import gridspec
18+
from . import backend
19+
from scipy.sparse import issparse, coo_array
1820

1921

2022
def plot1D_mat(
@@ -232,8 +234,6 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
232234
parameters given to the plot functions (default color is black if
233235
nothing given)
234236
"""
235-
from . import backend
236-
from scipy.sparse import issparse, coo_matrix
237237

238238
if ("color" not in kwargs) and ("c" not in kwargs):
239239
kwargs["color"] = "k"
@@ -258,7 +258,7 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
258258
# Not a backend array, check if scipy.sparse
259259
is_sparse = issparse(G)
260260
if is_sparse:
261-
G_coo = G if isinstance(G, coo_matrix) else G.tocoo()
261+
G_coo = G if isinstance(G, coo_array) else G.tocoo()
262262
rows, cols, data = G_coo.row, G_coo.col, G_coo.data
263263

264264
if is_sparse:

test/test_ot.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,15 @@ def test_emd_sparse_vs_dense(nx):
992992
b, nx.to_numpy(nx.sum(G_sparse_dense, 0)), rtol=1e-5, atol=1e-7
993993
)
994994

995+
# Test coo_array element-wise multiplication (only works with coo_array, not coo_matrix)
996+
if nx.__name__ == "numpy":
997+
# This tests that we're using coo_array which supports element-wise operations
998+
M_sparse_np = M_sparse
999+
G_sparse_np = G_sparse
1000+
loss_sparse = np.sum(G_sparse_np * M_sparse_np)
1001+
# Verify the loss calculation is reasonable
1002+
assert loss_sparse >= 0, "Sparse loss should be non-negative"
1003+
9951004

9961005
def test_emd2_sparse_vs_dense(nx):
9971006
"""Test that sparse and dense emd2 solvers produce identical costs.

0 commit comments

Comments
 (0)