3
3
Multi-lib backend for POT
4
4
5
5
The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
6
- or Jax , POT code should work nonetheless.
6
+ Jax, or Cupy, POT code should work nonetheless.
7
7
To achieve that, POT provides backend classes which implement functions in their respective backend
8
8
imitating Numpy API. As a convention, we use nx instead of np to refer to the backend.
9
9
44
44
jax = False
45
45
jax_type = float
46
46
47
# Optional CuPy backend: imported lazily so POT still works without a GPU stack.
try:
    import cupy as cp
    import cupyx  # provides cupyx.scipy.sparse, used by the CuPy backend
    cp_type = cp.ndarray
except ImportError:
    # CuPy is not installed: falsy/neutral placeholders keep the
    # `if cp:` guard and the isinstance(..., cp_type) dispatch valid.
    cp = False
    cp_type = float
47
55
str_type_error = "All array should be from the same type/backend. Current types are : {}"
48
56
49
57
@@ -57,6 +65,9 @@ def get_backend_list():
57
65
if jax :
58
66
lst .append (JaxBackend ())
59
67
68
+ if cp :
69
+ lst .append (CupyBackend ())
70
+
60
71
return lst
61
72
62
73
@@ -78,6 +89,8 @@ def get_backend(*args):
78
89
return TorchBackend ()
79
90
elif isinstance (args [0 ], jax_type ):
80
91
return JaxBackend ()
92
+ elif isinstance (args [0 ], cp_type ):
93
+ return CupyBackend ()
81
94
else :
82
95
raise ValueError ("Unknown type of non implemented backend." )
83
96
@@ -94,7 +107,8 @@ def to_numpy(*args):
94
107
class Backend ():
95
108
"""
96
109
Backend abstract class.
97
- Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
110
+ Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
111
+ :py:class:`CupyBackend`
98
112
99
113
- The `__name__` class attribute refers to the name of the backend.
100
114
- The `__type__` class attribute refers to the data structure used by the backend.
@@ -1500,3 +1514,287 @@ def assert_same_dtype_device(self, a, b):
1500
1514
1501
1515
assert a_dtype == b_dtype , "Dtype discrepancy"
1502
1516
assert a_device == b_device , f"Device discrepancy. First input is on { str (a_device )} , whereas second input is on { str (b_device )} "
1517
+
1518
+
1519
class CupyBackend(Backend):  # pragma: no cover
    """
    CuPy implementation of the backend.

    - `__name__` is "cupy"
    - `__type__` is cp.ndarray

    Arrays live on the GPU; whenever ``type_as`` is given, the new array is
    allocated on the same CUDA device (and with the same dtype) as that
    reference array.
    """

    __name__ = 'cupy'
    __type__ = cp_type
    __type_list__ = None

    # backend-local random generator, (re)seeded through seed()
    rng_ = None

    def __init__(self):
        self.rng_ = cp.random.RandomState()
        # one reference array per supported floating dtype
        self.__type_list__ = [
            cp.array(1, dtype=cp.float32),
            cp.array(1, dtype=cp.float64),
        ]

    def to_numpy(self, a):
        # device -> host copy
        return cp.asnumpy(a)

    def from_numpy(self, a, type_as=None):
        if type_as is None:
            return cp.asarray(a)
        # allocate on the same device and with the same dtype as `type_as`
        with cp.cuda.Device(type_as.device):
            return cp.asarray(a, dtype=type_as.dtype)

    def set_gradients(self, val, inputs, grads):
        # CuPy has no autodiff: gradients are silently dropped
        return val

    def zeros(self, shape, type_as=None):
        if isinstance(shape, (list, tuple)):
            shape = tuple(int(d) for d in shape)
        if type_as is None:
            return cp.zeros(shape)
        with cp.cuda.Device(type_as.device):
            return cp.zeros(shape, dtype=type_as.dtype)

    def ones(self, shape, type_as=None):
        if isinstance(shape, (list, tuple)):
            shape = tuple(int(d) for d in shape)
        if type_as is None:
            return cp.ones(shape)
        with cp.cuda.Device(type_as.device):
            return cp.ones(shape, dtype=type_as.dtype)

    def arange(self, stop, start=0, step=1, type_as=None):
        # NOTE: `type_as` is accepted for API symmetry but not used here
        return cp.arange(start, stop, step)

    def full(self, shape, fill_value, type_as=None):
        if isinstance(shape, (list, tuple)):
            shape = tuple(int(d) for d in shape)
        if type_as is None:
            return cp.full(shape, fill_value)
        with cp.cuda.Device(type_as.device):
            return cp.full(shape, fill_value, dtype=type_as.dtype)

    def eye(self, N, M=None, type_as=None):
        if type_as is None:
            return cp.eye(N, M)
        with cp.cuda.Device(type_as.device):
            return cp.eye(N, M, dtype=type_as.dtype)

    def sum(self, a, axis=None, keepdims=False):
        return cp.sum(a, axis=axis, keepdims=keepdims)

    def cumsum(self, a, axis=None):
        return cp.cumsum(a, axis=axis)

    def max(self, a, axis=None, keepdims=False):
        return cp.max(a, axis=axis, keepdims=keepdims)

    def min(self, a, axis=None, keepdims=False):
        return cp.min(a, axis=axis, keepdims=keepdims)

    def maximum(self, a, b):
        return cp.maximum(a, b)

    def minimum(self, a, b):
        return cp.minimum(a, b)

    def abs(self, a):
        return cp.abs(a)

    def exp(self, a):
        return cp.exp(a)

    def log(self, a):
        return cp.log(a)

    def sqrt(self, a):
        return cp.sqrt(a)

    def power(self, a, exponents):
        return cp.power(a, exponents)

    def dot(self, a, b):
        return cp.dot(a, b)

    def norm(self, a):
        # Euclidean (Frobenius) norm over all elements
        return cp.sqrt(cp.sum(cp.square(a)))

    def any(self, a):
        return cp.any(a)

    def isnan(self, a):
        return cp.isnan(a)

    def isinf(self, a):
        return cp.isinf(a)

    def einsum(self, subscripts, *operands):
        return cp.einsum(subscripts, *operands)

    def sort(self, a, axis=-1):
        return cp.sort(a, axis=axis)

    def argsort(self, a, axis=-1):
        return cp.argsort(a, axis=axis)

    def searchsorted(self, a, v, side='left'):
        if a.ndim == 1:
            return cp.searchsorted(a, v, side=side)
        # cp.searchsorted only supports 1d arrays: apply it row by row.
        # Not very efficient, but matches the 2d semantics of the other
        # backends.
        out = cp.empty(v.shape, dtype=int)
        for row in range(a.shape[0]):
            out[row, :] = cp.searchsorted(a[row, :], v[row, :], side=side)
        return out

    def flip(self, a, axis=None):
        return cp.flip(a, axis=axis)

    def outer(self, a, b):
        return cp.outer(a, b)

    def clip(self, a, a_min, a_max):
        return cp.clip(a, a_min, a_max)

    def repeat(self, a, repeats, axis=None):
        return cp.repeat(a, repeats, axis=axis)

    def take_along_axis(self, arr, indices, axis):
        return cp.take_along_axis(arr, indices, axis)

    def concatenate(self, arrays, axis=0):
        return cp.concatenate(arrays, axis=axis)

    def zero_pad(self, a, pad_width):
        # cp.pad defaults to constant-zero padding
        return cp.pad(a, pad_width)

    def argmax(self, a, axis=None):
        return cp.argmax(a, axis=axis)

    def mean(self, a, axis=None):
        return cp.mean(a, axis=axis)

    def std(self, a, axis=None):
        return cp.std(a, axis=axis)

    def linspace(self, start, stop, num):
        return cp.linspace(start, stop, num)

    def meshgrid(self, a, b):
        return cp.meshgrid(a, b)

    def diag(self, a, k=0):
        return cp.diag(a, k=k)

    def unique(self, a):
        return cp.unique(a)

    def logsumexp(self, a, axis=None):
        # Numerically stable log-sum-exp, adapted from
        # https://github.com/scipy/scipy/blob/v1.7.1/scipy/special/_logsumexp.py#L7-L127
        a_max = cp.amax(a, axis=axis, keepdims=True)

        # neutralize non-finite maxima so the shift below stays well-defined
        if a_max.ndim > 0:
            a_max[~cp.isfinite(a_max)] = 0
        elif not cp.isfinite(a_max):
            a_max = 0

        shifted = cp.exp(a - a_max)
        out = cp.log(cp.sum(shifted, axis=axis))
        a_max = cp.squeeze(a_max, axis=axis)
        out += a_max
        return out

    def stack(self, arrays, axis=0):
        return cp.stack(arrays, axis=axis)

    def reshape(self, a, shape):
        return cp.reshape(a, shape)

    def seed(self, seed=None):
        if seed is not None:
            self.rng_.seed(seed)

    def rand(self, *size, type_as=None):
        if type_as is None:
            return self.rng_.rand(*size)
        with cp.cuda.Device(type_as.device):
            return self.rng_.rand(*size, dtype=type_as.dtype)

    def randn(self, *size, type_as=None):
        if type_as is None:
            return self.rng_.randn(*size)
        with cp.cuda.Device(type_as.device):
            return self.rng_.randn(*size, dtype=type_as.dtype)

    def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
        # move the triplets to the GPU before building the sparse matrix
        data = self.from_numpy(data)
        rows = self.from_numpy(rows)
        cols = self.from_numpy(cols)
        if type_as is None:
            return cupyx.scipy.sparse.coo_matrix(
                (data, (rows, cols)), shape=shape)
        with cp.cuda.Device(type_as.device):
            return cupyx.scipy.sparse.coo_matrix(
                (data, (rows, cols)), shape=shape, dtype=type_as.dtype)

    def issparse(self, a):
        return cupyx.scipy.sparse.issparse(a)

    def tocsr(self, a):
        if self.issparse(a):
            return a.tocsr()
        return cupyx.scipy.sparse.csr_matrix(a)

    def eliminate_zeros(self, a, threshold=0.):
        if threshold > 0:
            # zero-out near-zero entries first so they get pruned below
            if self.issparse(a):
                a.data[self.abs(a.data) <= threshold] = 0
            else:
                a[self.abs(a) <= threshold] = 0
        if self.issparse(a):
            a.eliminate_zeros()
        return a

    def todense(self, a):
        if self.issparse(a):
            return a.toarray()
        return a

    def where(self, condition, x, y):
        return cp.where(condition, x, y)

    def copy(self, a):
        return a.copy()

    def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        return cp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

    def dtype_device(self, a):
        return a.dtype, a.device

    def assert_same_dtype_device(self, a, b):
        a_dtype, a_device = self.dtype_device(a)
        b_dtype, b_device = self.dtype_device(b)

        # cupy has implicit type conversion so
        # we automatically validate the test for type
        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
0 commit comments