Commit ca69658

[MRG] Cupy backend (#315)

* Cupy backend
* pep8 + bug
* working even if cupy not installed
* attempt to force codecov to ignore cupy because no gpu can be used for testing on github
* docstring
* readme

1 parent cb51064 commit ca69658

File tree

7 files changed: 355 additions (+), 26 deletions (−)

README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -196,6 +196,7 @@ The contributors to this library are
 * [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
 * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
 * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
+* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
 
 This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
 
```
ot/backend.py

Lines changed: 300 additions & 2 deletions
```diff
@@ -3,7 +3,7 @@
 Multi-lib backend for POT
 
 The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
-or Jax, POT code should work nonetheless.
+Jax, or Cupy, POT code should work nonetheless.
 To achieve that, POT provides backend classes which implements functions in their respective backend
 imitating Numpy API. As a convention, we use nx instead of np to refer to the backend.
 
```
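
This docstring change is a good place to illustrate the convention it describes: user code fetches the backend from its inputs and calls `nx` methods instead of `np` ones. A minimal sketch, not part of this commit; the helper name `pairwise_sq_dist` is hypothetical:

```python
from ot.backend import get_backend

def pairwise_sq_dist(x, y):
    # hypothetical helper: the same code runs on Numpy, PyTorch,
    # Jax or Cupy inputs, because all array ops go through nx
    nx = get_backend(x, y)
    x2 = nx.sum(x ** 2, axis=1)
    y2 = nx.sum(y ** 2, axis=1)
    return nx.reshape(x2, (-1, 1)) + nx.reshape(y2, (1, -1)) - 2 * nx.dot(x, y.T)
```
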
```diff
@@ -44,6 +44,14 @@
 jax = False
 jax_type = float
 
+try:
+    import cupy as cp
+    import cupyx
+    cp_type = cp.ndarray
+except ImportError:
+    cp = False
+    cp_type = float
+
 str_type_error = "All array should be from the same type/backend. Current types are : {}"
 
 
```

```diff
@@ -57,6 +65,9 @@ def get_backend_list():
     if jax:
         lst.append(JaxBackend())
 
+    if cp:
+        lst.append(CupyBackend())
+
     return lst
 
 
```
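
Since the guarded import above sets `cp = False` when CuPy is unavailable, this `if cp:` check silently skips the GPU backend on machines without it. A quick way to inspect which backends were detected (output is illustrative):

```python
from ot.backend import get_backend_list

# CupyBackend only appears here if `import cupy` succeeded at module load
print([nx.__name__ for nx in get_backend_list()])
# e.g. ['numpy', 'torch', 'jax', 'cupy'] with every backend installed
```
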

```diff
@@ -78,6 +89,8 @@ def get_backend(*args):
         return TorchBackend()
     elif isinstance(args[0], jax_type):
         return JaxBackend()
+    elif isinstance(args[0], cp_type):
+        return CupyBackend()
     else:
         raise ValueError("Unknown type of non implemented backend.")
 
```
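
With the new branch, `get_backend` dispatches on the array type, so CuPy arrays select `CupyBackend` automatically. A minimal sketch, assuming CuPy and a CUDA device are available:

```python
import cupy as cp
from ot.backend import get_backend

a = cp.random.rand(5)
b = cp.random.rand(5)
nx = get_backend(a, b)   # isinstance(a, cp.ndarray) -> CupyBackend
print(nx.__name__)       # 'cupy'
```
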

```diff
@@ -94,7 +107,8 @@ def to_numpy(*args):
 class Backend():
     """
     Backend abstract class.
-    Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+    Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
+    :py:class:`CupyBackend`
 
     - The `__name__` class attribute refers to the name of the backend.
     - The `__type__` class attribute refers to the data structure used by the backend.
```
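
For reference, the two class attributes mentioned in the docstring are plain introspection hooks; e.g. with the existing Numpy implementation:

```python
from ot.backend import NumpyBackend

nx = NumpyBackend()
print(nx.__name__)   # 'numpy'
print(nx.__type__)   # <class 'numpy.ndarray'>
```
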
```diff
@@ -1500,3 +1514,287 @@ def assert_same_dtype_device(self, a, b):
 
         assert a_dtype == b_dtype, "Dtype discrepancy"
         assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+
+class CupyBackend(Backend):  # pragma: no cover
+    """
+    CuPy implementation of the backend
+
+    - `__name__` is "cupy"
+    - `__type__` is cp.ndarray
+    """
+
+    __name__ = 'cupy'
+    __type__ = cp_type
+    __type_list__ = None
+
+    rng_ = None
+
+    def __init__(self):
+        self.rng_ = cp.random.RandomState()
+
+        self.__type_list__ = [
+            cp.array(1, dtype=cp.float32),
+            cp.array(1, dtype=cp.float64)
+        ]
+
+    def to_numpy(self, a):
+        return cp.asnumpy(a)
+
+    def from_numpy(self, a, type_as=None):
+        if type_as is None:
+            return cp.asarray(a)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cp.asarray(a, dtype=type_as.dtype)
+
+    def set_gradients(self, val, inputs, grads):
+        # No gradients for cupy
+        return val
+
+    def zeros(self, shape, type_as=None):
+        if isinstance(shape, (list, tuple)):
+            shape = tuple(int(i) for i in shape)
+        if type_as is None:
+            return cp.zeros(shape)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cp.zeros(shape, dtype=type_as.dtype)
+
+    def ones(self, shape, type_as=None):
+        if isinstance(shape, (list, tuple)):
+            shape = tuple(int(i) for i in shape)
+        if type_as is None:
+            return cp.ones(shape)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cp.ones(shape, dtype=type_as.dtype)
+
+    def arange(self, stop, start=0, step=1, type_as=None):
+        return cp.arange(start, stop, step)
+
+    def full(self, shape, fill_value, type_as=None):
+        if isinstance(shape, (list, tuple)):
+            shape = tuple(int(i) for i in shape)
+        if type_as is None:
+            return cp.full(shape, fill_value)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cp.full(shape, fill_value, dtype=type_as.dtype)
+
+    def eye(self, N, M=None, type_as=None):
+        if type_as is None:
+            return cp.eye(N, M)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cp.eye(N, M, dtype=type_as.dtype)
+
+    def sum(self, a, axis=None, keepdims=False):
+        return cp.sum(a, axis, keepdims=keepdims)
+
+    def cumsum(self, a, axis=None):
+        return cp.cumsum(a, axis)
+
+    def max(self, a, axis=None, keepdims=False):
+        return cp.max(a, axis, keepdims=keepdims)
+
+    def min(self, a, axis=None, keepdims=False):
+        return cp.min(a, axis, keepdims=keepdims)
+
+    def maximum(self, a, b):
+        return cp.maximum(a, b)
+
+    def minimum(self, a, b):
+        return cp.minimum(a, b)
+
+    def abs(self, a):
+        return cp.abs(a)
+
+    def exp(self, a):
+        return cp.exp(a)
+
+    def log(self, a):
+        return cp.log(a)
+
+    def sqrt(self, a):
+        return cp.sqrt(a)
+
+    def power(self, a, exponents):
+        return cp.power(a, exponents)
+
+    def dot(self, a, b):
+        return cp.dot(a, b)
+
+    def norm(self, a):
+        return cp.sqrt(cp.sum(cp.square(a)))
+
+    def any(self, a):
+        return cp.any(a)
+
+    def isnan(self, a):
+        return cp.isnan(a)
+
+    def isinf(self, a):
+        return cp.isinf(a)
+
+    def einsum(self, subscripts, *operands):
+        return cp.einsum(subscripts, *operands)
+
+    def sort(self, a, axis=-1):
+        return cp.sort(a, axis)
+
+    def argsort(self, a, axis=-1):
+        return cp.argsort(a, axis)
+
+    def searchsorted(self, a, v, side='left'):
+        if a.ndim == 1:
+            return cp.searchsorted(a, v, side)
+        else:
+            # this is a not very efficient way to make numpy
+            # searchsorted work on 2d arrays
+            ret = cp.empty(v.shape, dtype=int)
+            for i in range(a.shape[0]):
+                ret[i, :] = cp.searchsorted(a[i, :], v[i, :], side)
+            return ret
+
+    def flip(self, a, axis=None):
+        return cp.flip(a, axis)
+
+    def outer(self, a, b):
+        return cp.outer(a, b)
+
+    def clip(self, a, a_min, a_max):
+        return cp.clip(a, a_min, a_max)
+
+    def repeat(self, a, repeats, axis=None):
+        return cp.repeat(a, repeats, axis)
+
+    def take_along_axis(self, arr, indices, axis):
+        return cp.take_along_axis(arr, indices, axis)
+
+    def concatenate(self, arrays, axis=0):
+        return cp.concatenate(arrays, axis)
+
+    def zero_pad(self, a, pad_width):
+        return cp.pad(a, pad_width)
+
+    def argmax(self, a, axis=None):
+        return cp.argmax(a, axis=axis)
+
+    def mean(self, a, axis=None):
+        return cp.mean(a, axis=axis)
+
+    def std(self, a, axis=None):
+        return cp.std(a, axis=axis)
+
+    def linspace(self, start, stop, num):
+        return cp.linspace(start, stop, num)
+
+    def meshgrid(self, a, b):
+        return cp.meshgrid(a, b)
+
+    def diag(self, a, k=0):
+        return cp.diag(a, k)
+
+    def unique(self, a):
+        return cp.unique(a)
+
+    def logsumexp(self, a, axis=None):
+        # Taken from
+        # https://github.com/scipy/scipy/blob/v1.7.1/scipy/special/_logsumexp.py#L7-L127
+        a_max = cp.amax(a, axis=axis, keepdims=True)
+
+        if a_max.ndim > 0:
+            a_max[~cp.isfinite(a_max)] = 0
+        elif not cp.isfinite(a_max):
+            a_max = 0
+
+        tmp = cp.exp(a - a_max)
+        s = cp.sum(tmp, axis=axis)
+        out = cp.log(s)
+        a_max = cp.squeeze(a_max, axis=axis)
+        out += a_max
+        return out
+
+    def stack(self, arrays, axis=0):
+        return cp.stack(arrays, axis)
+
+    def reshape(self, a, shape):
+        return cp.reshape(a, shape)
+
+    def seed(self, seed=None):
+        if seed is not None:
+            self.rng_.seed(seed)
+
+    def rand(self, *size, type_as=None):
+        if type_as is None:
+            return self.rng_.rand(*size)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return self.rng_.rand(*size, dtype=type_as.dtype)
+
+    def randn(self, *size, type_as=None):
+        if type_as is None:
+            return self.rng_.randn(*size)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return self.rng_.randn(*size, dtype=type_as.dtype)
+
+    def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+        data = self.from_numpy(data)
+        rows = self.from_numpy(rows)
+        cols = self.from_numpy(cols)
+        if type_as is None:
+            return cupyx.scipy.sparse.coo_matrix(
+                (data, (rows, cols)), shape=shape
+            )
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cupyx.scipy.sparse.coo_matrix(
+                    (data, (rows, cols)), shape=shape, dtype=type_as.dtype
+                )
+
+    def issparse(self, a):
+        return cupyx.scipy.sparse.issparse(a)
+
+    def tocsr(self, a):
+        if self.issparse(a):
+            return a.tocsr()
+        else:
+            return cupyx.scipy.sparse.csr_matrix(a)
+
+    def eliminate_zeros(self, a, threshold=0.):
+        if threshold > 0:
+            if self.issparse(a):
+                a.data[self.abs(a.data) <= threshold] = 0
+            else:
+                a[self.abs(a) <= threshold] = 0
+        if self.issparse(a):
+            a.eliminate_zeros()
+        return a
+
+    def todense(self, a):
+        if self.issparse(a):
+            return a.toarray()
+        else:
+            return a
+
+    def where(self, condition, x, y):
+        return cp.where(condition, x, y)
+
+    def copy(self, a):
+        return a.copy()
+
+    def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+        return cp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+    def dtype_device(self, a):
+        return a.dtype, a.device
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        # cupy has implicit type conversion so
+        # we automatically validate the test for type
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
```
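
A recurring pattern above is the `with cp.cuda.Device(type_as.device):` guard: whenever `type_as` is given, new arrays are allocated on the same GPU and with the same dtype as the reference array. A usage sketch, assuming CuPy and at least one CUDA device; the variable names are illustrative:

```python
import cupy as cp
from ot.backend import CupyBackend

nx = CupyBackend()

# reference array on the default GPU, in float32
ref = cp.asarray([1.0, 2.0, 3.0], dtype=cp.float32)

# type_as makes new allocations match ref's dtype and device
z = nx.zeros((2, 3), type_as=ref)
r = nx.rand(2, 3, type_as=ref)

assert z.dtype == cp.float32
nx.assert_same_dtype_device(ref, z)  # only the device is checked for cupy
```
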
