 #
 # License: MIT License

-import numpy as np
-import cudamat
+import cupy as np  # np is used for matrix computation
+import cupy as cp  # cp is used for cupy-specific operations
+from . import utils


-def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
-             log=False, returnAsGPU=False):
-    r"""
-    Solve the entropic regularization optimal transport problem on GPU
+
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
+                   verbose=False, log=False, to_numpy=True, **kwargs):
+    """
+    Solve the entropic regularization optimal transport problem and return the OT matrix

     The function solves the following optimization problem:
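Note the aliasing trick in the new imports: `np` is bound to cupy, so the NumPy-style calls in the body below run on the GPU unchanged, while `cp` is kept for explicitly cupy-only operations such as `cp.asarray`. The `utils.to_np` helper used near the end of this diff is not shown here; a minimal sketch of what it presumably does, assuming it only copies a device array back to host memory:

    import cupy as cp

    def to_np(A):
        # hypothetical stand-in for the unshown utils.to_np helper:
        # copy a cupy array from GPU memory into a host numpy array
        return cp.asnumpy(A)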
@@ -40,9 +42,10 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
     ----------
     a : np.ndarray (ns,)
         sample weights in the source domain
-    b : np.ndarray (nt,)
-        samples in the target domain
-    M_GPU : cudamat.CUDAMatrix (ns,nt)
+    b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+        sample weights in the target domain; if b is a matrix, compute
+        sinkhorn with multiple targets and a fixed M (returns the OT
+        losses, with the dual variables in log)
+    M : np.ndarray (ns,nt)
         loss matrix
     reg : float
         Regularization term >0
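As the updated docstring states, `b` may now be a matrix of several target histograms sharing one loss matrix `M`. A short usage sketch of this batched mode, assuming `sinkhorn_knopp` is in scope (the shapes and values are illustrative only):

    import numpy as np

    ns, nt, nbb = 50, 60, 4
    a = np.ones(ns) / ns           # source weights, shape (ns,)
    b = np.ones((nt, nbb)) / nt    # nbb target histograms, shape (nt, nbb)
    M = np.random.rand(ns, nt)     # fixed loss matrix, shape (ns, nt)
    losses = sinkhorn_knopp(a, b, M, reg=1e-1)  # one OT loss per column of b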
@@ -54,8 +57,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
         Print information along iterations
     log : bool, optional
         record log if True
-    returnAsGPU : bool, optional
-        return the OT matrix as a cudamat.CUDAMatrix
+

     Returns
     -------
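The cudamat-era `returnAsGPU` flag is gone; its role is taken by the new `to_numpy` keyword (default `True`), so a caller can keep the result on the device instead of copying it to host memory. A one-line sketch, assuming `a`, `b`, `M` are already defined:

    # to_numpy=False keeps the result as a cupy.ndarray on the GPU
    G = sinkhorn_knopp(a, b, M, reg=1e-1, to_numpy=False)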
@@ -88,60 +90,78 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
     ot.optim.cg : General regularized OT

     """
+
+    a = cp.asarray(a, dtype=np.float64)
+    b = cp.asarray(b, dtype=np.float64)
+    M = cp.asarray(M, dtype=np.float64)
+
+    if len(a) == 0:
+        a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+    if len(b) == 0:
+        b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
     # init data
     Nini = len(a)
     Nfin = len(b)

+    if len(b.shape) > 1:
+        nbb = b.shape[1]
+    else:
+        nbb = 0
+
     if log:
         log = {'err': []}

     # we assume that no distances are null except those on the diagonal
     # of the distance matrix
-    u = (np.ones(Nini) / Nini).reshape((Nini, 1))
-    u_GPU = cudamat.CUDAMatrix(u)
-    a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1)))
-    ones_GPU = cudamat.empty(u_GPU.shape).assign(1)
-    v = (np.ones(Nfin) / Nfin).reshape((Nfin, 1))
-    v_GPU = cudamat.CUDAMatrix(v)
-    b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1)))
-
-    M_GPU.divide(-reg)
+    if nbb:
+        u = np.ones((Nini, nbb)) / Nini
+        v = np.ones((Nfin, nbb)) / Nfin
+    else:
+        u = np.ones(Nini) / Nini
+        v = np.ones(Nfin) / Nfin

-    K_GPU = cudamat.exp(M_GPU)

-    ones_GPU.divide(a_GPU, target=a_GPU)
-    Kp_GPU = cudamat.empty(K_GPU.shape)
-    K_GPU.mult_by_col(a_GPU, target=Kp_GPU)
+    # the next 3 lines are equivalent to K = np.exp(-M / reg),
+    # but faster to compute
+    K = np.empty(M.shape, dtype=M.dtype)
+    np.divide(M, -reg, out=K)
+    np.exp(K, out=K)

-    tmp_GPU = cudamat.empty(K_GPU.shape)
+    tmp2 = np.empty(b.shape, dtype=M.dtype)

+    Kp = (1 / a).reshape(-1, 1) * K
     cpt = 0
     err = 1
     while (err > stopThr and cpt < numItermax):
-        uprev_GPU = u_GPU.copy()
-        vprev_GPU = v_GPU.copy()
+        uprev = u
+        vprev = v

-        KtransposeU_GPU = K_GPU.transpose().dot(u_GPU)
-        b_GPU.divide(KtransposeU_GPU, target=v_GPU)
-        ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU)
+        KtransposeU = np.dot(K.T, u)
+        v = np.divide(b, KtransposeU)
+        u = 1. / np.dot(Kp, v)

-        if (np.any(KtransposeU_GPU.asarray() == 0) or
-                not u_GPU.allfinite() or not v_GPU.allfinite()):
+        if (np.any(KtransposeU == 0) or
+                np.any(np.isnan(u)) or np.any(np.isnan(v)) or
+                np.any(np.isinf(u)) or np.any(np.isinf(v))):
             # we have reached the machine precision
             # come back to the previous solution and quit the loop
             print('Warning: numerical errors at iteration', cpt)
-            u_GPU = uprev_GPU.copy()
-            v_GPU = vprev_GPU.copy()
+            u = uprev
+            v = vprev
             break
         if cpt % 10 == 0:
             # we can speed up the process by checking for the error only
             # every 10th iteration
-            K_GPU.mult_by_col(u_GPU, target=tmp_GPU)
-            tmp_GPU.mult_by_row(v_GPU.transpose(), target=tmp_GPU)
-
-            bcopy_GPU = b_GPU.copy().transpose()
-            bcopy_GPU.add_sums(tmp_GPU, axis=0, beta=-1)
-            err = bcopy_GPU.euclid_norm()**2
+            if nbb:
+                err = np.sum((u - uprev)**2) / np.sum(u**2) + \
+                    np.sum((v - vprev)**2) / np.sum(v**2)
+            else:
+                # compute the right marginal tmp2 = (diag(u) K diag(v))^T 1
+                tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
+                # tmp2 = np.einsum('i,ij,j->j', u, K, v)
+                err = np.linalg.norm(tmp2 - b)**2  # violation of the marginal
         if log:
             log['err'].append(err)
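The loop above is the standard Sinkhorn-Knopp scaling: with K = exp(-M / reg), it alternates v = b / (K^T u) and u = a / (K v); the diff writes the second update as 1 / (Kp v) with Kp = diag(1/a) K precomputed, which is the same thing. For reference, a minimal CPU-only NumPy sketch of these updates, without the safeguards, batching, and logging of the real function:

    import numpy as np

    def sinkhorn_sketch(a, b, M, reg, n_iter=1000):
        # bare Sinkhorn-Knopp scaling loop mirroring the updates above
        K = np.exp(-M / reg)
        u = np.ones(len(a)) / len(a)
        v = np.ones(len(b)) / len(b)
        for _ in range(n_iter):
            v = b / (K.T @ u)  # enforce the target marginal
            u = a / (K @ v)    # enforce the source marginal
        return u[:, None] * K * v[None, :]  # OT plan diag(u) K diag(v)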
@@ -150,20 +170,31 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
                 print(
                     '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
             print('{:5d}|{:8e}|'.format(cpt, err))
-        cpt += 1
-    if log:
-        log['u'] = u_GPU.asarray()
-        log['v'] = v_GPU.asarray()
-
-    K_GPU.mult_by_col(u_GPU, target=K_GPU)
-    K_GPU.mult_by_row(v_GPU.transpose(), target=K_GPU)
-
-    if returnAsGPU:
-        res = K_GPU
-    else:
-        res = K_GPU.asarray()
-
+        cpt = cpt + 1
     if log:
-        return res, log
-    else:
-        return res
+        log['u'] = u
+        log['v'] = v
+
+    if nbb:  # return only the loss
+        # res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) explodes cupy memory
+        res = np.empty(nbb)
+        for i in range(nbb):
+            res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i])
+        if to_numpy:
+            res = utils.to_np(res)
+        if log:
+            return res, log
+        else:
+            return res
+
+    else:  # return the OT matrix
+        res = u.reshape((-1, 1)) * K * v.reshape((1, -1))
+        if to_numpy:
+            res = utils.to_np(res)
+        if log:
+            return res, log
+        else:
+            return res
+
+
+# define sinkhorn as sinkhorn_knopp
+sinkhorn = sinkhorn_knopp
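Taken together, an end-to-end usage sketch of the new code path. The `ot.gpu.bregman` module path is an assumption inferred from the relative `from . import utils` import, and the data are random placeholders:

    import numpy as np
    from ot.gpu import bregman  # module path assumed, not shown in the diff

    ns, nt = 100, 120
    a = np.ones(ns) / ns
    b = np.ones(nt) / nt
    M = np.random.rand(ns, nt)

    # runs on the GPU via cupy; returned as a numpy array since to_numpy=True
    G = bregman.sinkhorn(a, b, M, reg=1e-1)
    print(G.shape)  # (100, 120)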