Skip to content

Commit f45f7a6

Browse files
committed
pep8
1 parent d258c7d commit f45f7a6

File tree

3 files changed

+34
-36
lines changed

3 files changed

+34
-36
lines changed

ot/gpu/bregman.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
#
99
# License: MIT License
1010

11-
import cupy as np # np used for matrix computation
12-
import cupy as cp # cp used for cupy specific operations
11+
import cupy as np # np used for matrix computation
12+
import cupy as cp # cp used for cupy specific operations
1313
from . import utils
1414

1515

16-
1716
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
1817
verbose=False, log=False, to_numpy=True, **kwargs):
1918
"""
@@ -159,7 +158,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
159158
np.sum((v - vprev)**2) / np.sum((v)**2)
160159
else:
161160
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
162-
tmp2=np.sum(u[:,None]*K*v[None,:],0)
161+
tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
163162
#tmp2=np.einsum('i,ij,j->j', u, K, v)
164163
err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
165164
if log:
@@ -177,24 +176,25 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
177176

178177
if nbb: # return only loss
179178
#res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory)
180-
res=np.empty(nbb)
179+
res = np.empty(nbb)
181180
for i in range(nbb):
182-
res[i]=np.sum(u[:,None,i]*(K*M)*v[None,:,i])
181+
res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i])
183182
if to_numpy:
184-
res=utils.to_np(res)
183+
res = utils.to_np(res)
185184
if log:
186185
return res, log
187186
else:
188187
return res
189188

190189
else: # return OT matrix
191-
res=u.reshape((-1, 1)) * K * v.reshape((1, -1))
190+
res = u.reshape((-1, 1)) * K * v.reshape((1, -1))
192191
if to_numpy:
193-
res=utils.to_np(res)
192+
res = utils.to_np(res)
194193
if log:
195194
return res, log
196195
else:
197196
return res
198197

198+
199199
# define sinkhorn as sinkhorn_knopp
200-
sinkhorn=sinkhorn_knopp
200+
sinkhorn = sinkhorn_knopp

ot/gpu/da.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
#
1111
# License: MIT License
1212

13-
import cupy as np # np used for matrix computation
14-
import cupy as cp # cp used for cupy specific operations
13+
import cupy as np # np used for matrix computation
14+
import cupy as cp # cp used for cupy specific operations
1515
import numpy as npp
1616
from . import utils
1717
from .bregman import sinkhorn
1818

19+
1920
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
2021
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
21-
log=False,to_numpy=True):
22+
log=False, to_numpy=True):
2223
"""
2324
Solve the entropic regularization optimal transport problem with nonconvex
2425
group lasso regularization
@@ -101,15 +102,14 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
101102
ot.optim.cg : General regularized OT
102103
103104
"""
104-
105+
105106
a, labels_a, b, M = utils.to_gpu(a, labels_a, b, M)
106-
107-
107+
108108
p = 0.5
109109
epsilon = 1e-3
110110

111111
indices_labels = []
112-
labels_a2=cp.asnumpy(labels_a)
112+
labels_a2 = cp.asnumpy(labels_a)
113113
classes = npp.unique(labels_a2)
114114
for c in classes:
115115
idxc, = utils.to_gpu(npp.where(labels_a2 == c))
@@ -120,7 +120,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
120120
for cpt in range(numItermax):
121121
Mreg = M + eta * W
122122
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
123-
stopThr=stopInnerThr,to_numpy=False)
123+
stopThr=stopInnerThr, to_numpy=False)
124124
# the transport has been computed. Check if classes are really
125125
# separated
126126
W = np.ones(M.shape)

ot/gpu/utils.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
Utility functions for GPU
3+
Utility functions for GPU
44
"""
55

66
# Author: Remi Flamary <[email protected]>
@@ -9,9 +9,8 @@
99
#
1010
# License: MIT License
1111

12-
import cupy as np # np used for matrix computation
13-
import cupy as cp # cp used for cupy specific operations
14-
12+
import cupy as np # np used for matrix computation
13+
import cupy as cp # cp used for cupy specific operations
1514

1615

1716
def euclidean_distances(a, b, squared=False, to_numpy=True):
@@ -34,23 +33,24 @@ def euclidean_distances(a, b, squared=False, to_numpy=True):
3433
c : (n x m) np.ndarray or cupy.ndarray
3534
pairwise euclidean distance distance matrix
3635
"""
37-
36+
3837
a, b = to_gpu(a, b)
39-
40-
a2=np.sum(np.square(a),1)
41-
b2=np.sum(np.square(b),1)
42-
43-
c=-2*np.dot(a,b.T)
44-
c+=a2[:,None]
45-
c+=b2[None,:]
46-
38+
39+
a2 = np.sum(np.square(a), 1)
40+
b2 = np.sum(np.square(b), 1)
41+
42+
c = -2 * np.dot(a, b.T)
43+
c += a2[:, None]
44+
c += b2[None, :]
45+
4746
if not squared:
4847
np.sqrt(c, out=c)
4948
if to_numpy:
5049
return to_np(c)
5150
else:
5251
return c
5352

53+
5454
def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True):
5555
"""Compute distance between samples in x1 and x2 on gpu
5656
@@ -61,8 +61,8 @@ def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True):
6161
matrix with n1 samples of size d
6262
x2 : np.array (n2,d), optional
6363
matrix with n2 samples of size d (if None then x2=x1)
64-
metric : str
65-
Metric from 'sqeuclidean', 'euclidean',
64+
metric : str
65+
Metric from 'sqeuclidean', 'euclidean',
6666
6767
6868
Returns
@@ -80,7 +80,6 @@ def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True):
8080
return euclidean_distances(x1, x2, squared=False, to_numpy=to_numpy)
8181
else:
8282
raise NotImplementedError
83-
8483

8584

8685
def to_gpu(*args):
@@ -91,10 +90,9 @@ def to_gpu(*args):
9190
return cp.asarray(args[0])
9291

9392

94-
9593
def to_np(*args):
9694
""" convert GPU arras to numpy and return them"""
9795
if len(args) > 1:
9896
return (cp.asnumpy(x) for x in args)
9997
else:
100-
return cp.asnumpy(args[0])
98+
return cp.asnumpy(args[0])

0 commit comments

Comments
 (0)