Skip to content

Commit 8c724ad

Browse files
committed
partial with tests
1 parent fff2463 commit 8c724ad

File tree

4 files changed

+39
-149
lines changed

4 files changed

+39
-149
lines changed

examples/plot_partial_wass_and_gromov.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@
3333
cov = np.array([[1, 0], [0, 2]])
3434

3535
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
36-
xs = np.append(xs, (np.random.rand(n_noise, 2)+1)*4).reshape((-1, 2))
36+
xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
3737
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
38-
xt = np.append(xt, (np.random.rand(n_noise, 2)+1)*-3).reshape((-1, 2))
38+
xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
3939

4040
M = sp.spatial.distance.cdist(xs, xt)
4141

@@ -62,7 +62,7 @@
6262
log=True)
6363

6464
print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist']))
65-
print('Entropic partial Wasserstein distance (m = 0.5): ' + \
65+
print('Entropic partial Wasserstein distance (m = 0.5): ' +
6666
str(log['partial_w_dist']))
6767

6868
pl.figure(1, (10, 5))
@@ -98,10 +98,10 @@
9898

9999

100100
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
101-
xs = np.concatenate((xs, ((np.random.rand(n_noise, 2)+1)*4)), axis=0)
101+
xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
102102
P = sp.linalg.sqrtm(cov_t)
103103
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
104-
xt = np.concatenate((xt, ((np.random.rand(n_noise, 3)+1)*10)), axis=0)
104+
xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
105105

106106
fig = pl.figure()
107107
ax1 = fig.add_subplot(121)
@@ -128,7 +128,7 @@
128128
m=m, log=True)
129129

130130
print('Partial Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
131-
print('Entropic partial Wasserstein distance (m = 1): ' + \
131+
print('Entropic partial Wasserstein distance (m = 1): ' +
132132
str(log['partial_gw_dist']))
133133

134134
pl.figure(1, (10, 5))
@@ -142,14 +142,14 @@
142142
pl.show()
143143

144144
print('-----m = 2/3')
145-
m = 2/3
145+
m = 2 / 3
146146
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
147147
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
148148
m=m, log=True)
149149

150-
print('Partial Wasserstein distance (m = 2/3): ' + \
150+
print('Partial Wasserstein distance (m = 2/3): ' +
151151
str(log0['partial_gw_dist']))
152-
print('Entropic partial Wasserstein distance (m = 2/3): ' + \
152+
print('Entropic partial Wasserstein distance (m = 2/3): ' +
153153
str(log['partial_gw_dist']))
154154

155155
pl.figure(1, (10, 5))

ot/__init__.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

ot/partial.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99

1010
import numpy as np
1111

12-
from ot.lp import emd
12+
from .lp import emd
1313

1414

1515
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
1616
**kwargs):
17-
1817
r"""
1918
Solves the partial optimal transport problem for the quadratic cost
2019
and returns the OT plan
@@ -136,7 +135,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
136135
if log_emd['warning'] is not None:
137136
raise ValueError("Error in the EMD resolution: try to increase the"
138137
" number of dummy points")
139-
log_emd['cost'] = np.sum(gamma*M)
138+
log_emd['cost'] = np.sum(gamma * M)
140139
if log:
141140
return gamma, log_emd
142141
else:
@@ -233,7 +232,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
233232

234233
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
235234
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
236-
M_extended = np.ones((len(a_extended), len(b_extended))) * np.max(M) * 1e2
235+
M_extended = np.ones((len(a_extended), len(b_extended))) * 0
237236
M_extended[-1, -1] = np.max(M) * 1e5
238237
M_extended[:len(a), :len(b)] = M
239238

@@ -381,7 +380,7 @@ def gwloss_partial(C1, C2, T):
381380
382381
Returns
383382
-------
384-
GW loss
383+
GW loss
385384
"""
386385
g = gwgrad_partial(C1, C2, T) * 0.5
387386
return np.sum(g * T)
@@ -432,7 +431,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
432431
G0 : ndarray, shape (ns, nt), optional
433432
Initialisation of the transportation matrix
434433
thres : float, optional
435-
quantile of the gradient matrix to populate the cost matrix when 0
434+
quantile of the gradient matrix to populate the cost matrix when 0
436435
(default: 1)
437436
numItermax : int, optional
438437
Max number of iterations
@@ -566,7 +565,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
566565
where :
567566
568567
- M is the metric cost matrix
569-
- :math:`\Omega` is the entropic regularization term
568+
- :math:`\Omega` is the entropic regularization term
570569
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
571570
- a and b are the sample weights
572571
- m is the amount of mass to be transported
@@ -591,7 +590,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
591590
G0 : ndarray, shape (ns, nt), optional
592591
Initialisation of the transportation matrix
593592
thres : float, optional
594-
quantile of the gradient matrix to populate the cost matrix when 0
593+
quantile of the gradient matrix to populate the cost matrix when 0
595594
(default: 1)
596595
numItermax : int, optional
597596
Max number of iterations
@@ -666,7 +665,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
666665
where :
667666
668667
- M is the metric cost matrix
669-
- :math:`\Omega` is the entropic regularization term
668+
- :math:`\Omega` is the entropic regularization term
670669
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
671670
- a and b are the sample weights
672671
- m is the amount of mass to be transported
@@ -754,7 +753,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
754753
K = np.empty(M.shape, dtype=M.dtype)
755754
np.divide(M, -reg, out=K)
756755
np.exp(K, out=K)
757-
np.multiply(K, m/np.sum(K), out=K)
756+
np.multiply(K, m / np.sum(K), out=K)
758757

759758
err, cpt = 1, 0
760759

@@ -809,7 +808,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
809808
- C2 is the metric cost matrix in the target space
810809
- p and q are the sample weights
811810
- L : quadratic loss function
812-
- :math:`\Omega` is the entropic regularization term
811+
- :math:`\Omega` is the entropic regularization term
813812
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
814813
- m is the amount of mass to be transported
815814
@@ -944,7 +943,7 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
944943
- C2 is the metric cost matrix in the target space
945944
- p and q are the sample weights
946945
- L : quadratic loss function
947-
- :math:`\Omega` is the entropic regularization term
946+
- :math:`\Omega` is the entropic regularization term
948947
:math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
949948
- m is the amount of mass to be transported
950949

ot/unbalanced.py

Lines changed: 19 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# from .utils import unif, dist
1515

1616

17-
def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numItermax=1000,
17+
def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
1818
stopThr=1e-6, verbose=False, log=False, **kwargs):
1919
r"""
2020
Solve the unbalanced entropic regularization optimal transport problem
@@ -120,20 +120,20 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numI
120120
"""
121121

122122
if method.lower() == 'sinkhorn':
123-
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div,
123+
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
124124
numItermax=numItermax,
125125
stopThr=stopThr, verbose=verbose,
126126
log=log, **kwargs)
127127

128128
elif method.lower() == 'sinkhorn_stabilized':
129-
return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div,
129+
return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
130130
numItermax=numItermax,
131131
stopThr=stopThr,
132132
verbose=verbose,
133133
log=log, **kwargs)
134134
elif method.lower() in ['sinkhorn_reg_scaling']:
135135
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
136-
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg,
136+
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
137137
numItermax=numItermax,
138138
stopThr=stopThr, verbose=verbose,
139139
log=log, **kwargs)
@@ -261,8 +261,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
261261
else:
262262
raise ValueError('Unknown method %s.' % method)
263263

264-
# TODO: update the doc
265-
def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000,
264+
265+
def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
266266
stopThr=1e-6, verbose=False, log=False, **kwargs):
267267
r"""
268268
Solve the entropic regularization unbalanced optimal transport problem and return the loss
@@ -349,7 +349,6 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000,
349349
"""
350350

351351
a = np.asarray(a, dtype=np.float64)
352-
print(a)
353352
b = np.asarray(b, dtype=np.float64)
354353
M = np.asarray(M, dtype=np.float64)
355354

@@ -377,39 +376,24 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000,
377376
else:
378377
u = np.ones(dim_a) / dim_a
379378
v = np.ones(dim_b) / dim_b
380-
u = np.ones(dim_a)
381-
v = np.ones(dim_b)
382379

383380
# Next 3 lines are equivalent to K = np.exp(-M / reg), but faster to compute
384381
K = np.empty(M.shape, dtype=M.dtype)
385-
np.true_divide(M, -reg, out=K)
382+
np.divide(M, -reg, out=K)
386383
np.exp(K, out=K)
387-
388-
if div == "KL":
389-
fi = reg_m / (reg_m + reg)
390-
elif div == "TV":
391-
fi = reg_m / reg
384+
385+
fi = reg_m / (reg_m + reg)
392386

393387
err = 1.
394-
395-
dx = np.ones(dim_a) / dim_a
396-
dy = np.ones(dim_b) / dim_b
397-
z = 1
398388

399389
for i in range(numItermax):
400390
uprev = u
401391
vprev = v
402392

403-
Kv = z*K.dot(v*dy)
404-
u = scaling_iter_prox(Kv, a, fi, div)
405-
#u = (a / Kv) ** fi
406-
Ktu = z*K.T.dot(u*dx)
407-
v = scaling_iter_prox(Ktu, b, fi, div)
408-
#v = (b / Ktu) ** fi
409-
#print(v*dy)
410-
z = np.dot((u*dx).T, np.dot(K,v*dy))/0.35
411-
print(z)
412-
393+
Kv = K.dot(v)
394+
u = (a / Kv) ** fi
395+
Ktu = K.T.dot(u)
396+
v = (b / Ktu) ** fi
413397

414398
if (np.any(Ktu == 0.)
415399
or np.any(np.isnan(u)) or np.any(np.isnan(v))
@@ -450,12 +434,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000,
450434
if log:
451435
return u[:, None] * K * v[None, :], log
452436
else:
453-
return z*u[:, None] * K * v[None, :]
437+
return u[:, None] * K * v[None, :]
438+
454439

455-
# TODO: update the doc
456-
def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5,
457-
numItermax=1000, stopThr=1e-6,
458-
verbose=False, log=False, **kwargs):
440+
def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
441+
stopThr=1e-6, verbose=False, log=False,
442+
**kwargs):
459443
r"""
460444
Solve the entropic regularization unbalanced optimal transport
461445
problem and return the loss
@@ -580,10 +564,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5,
580564
np.divide(M, -reg, out=K)
581565
np.exp(K, out=K)
582566

583-
if div == "KL":
584-
fi = reg_m / (reg_m + reg)
585-
elif div == "TV":
586-
fi = reg_m / reg
567+
fi = reg_m / (reg_m + reg)
587568

588569
cpt = 0
589570
err = 1.
@@ -669,15 +650,6 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5,
669650
else:
670651
return ot_matrix
671652

672-
def scaling_iter_prox(s, p, fi, div):
673-
if div == "KL":
674-
return (p / s) ** fi
675-
elif div == "TV":
676-
return np.minimum(s*np.exp(fi), np.maximum(s*np.exp(-fi), p)) / s
677-
else:
678-
raise ValueError("Unknown divergence '%s'." % div)
679-
680-
681653

682654
def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
683655
numItermax=1000, stopThr=1e-6,

0 commit comments

Comments
 (0)