Commit fd6371c

Kilian Fatras authored and committed
replaced marginal tests
1 parent b2b5ffc · commit fd6371c

File tree: 3 files changed, 31 insertions(+), 34 deletions(-)

ot/lp/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -18,8 +18,6 @@
 from ..utils import parmap
 from .cvx import barycenter
 
-__all__=['emd', 'emd2', 'barycenter', 'cvx']
-
 
 def emd(a, b, M, numItermax=100000, log=False):
     """Solves the Earth Movers distance problem and returns the OT matrix

ot/stochastic.py

Lines changed: 19 additions & 17 deletions
@@ -450,24 +450,29 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
 
     \forall j in batch_alpha,
     grad_beta_j = beta_j * batch_size/len(alpha) -
-        sum_{j in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
+        sum_{i in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
             * a_i * b_j
     where :
     - M is the (ns,nt) metric cost matrix
     - alpha, beta are dual variables in R^ixR^J
     - reg is the regularization term
-    - batch_alpha and batch_beta are list of index
+    - batch_alpha and batch_beta are lists of index
+    - a and b are source and target weights (sum to 1)
+
 
     The algorithm used for solving the dual problem is the SGD algorithm
     as proposed in [19]_ [alg.1]
 
     Parameters
     ----------
-
-    reg : float number,
-        Regularization term > 0
+    a : np.ndarray(ns,),
+        source measure
+    b : np.ndarray(nt,),
+        target measure
     M : np.ndarray(ns, nt),
         cost matrix
+    reg : float number,
+        Regularization term > 0
     alpha : np.ndarray(ns,)
         dual variable
     beta : np.ndarray(nt,)
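For readability, the corrected gradient from the hunk above in LaTeX (transcribing the docstring's own symbols, writing B_alpha for batch_alpha; the fix changes the summation index from j to i, so the sum runs over the sampled source points):

\nabla_{\beta_j} = \beta_j \cdot \frac{\text{batch\_size}}{\operatorname{len}(\alpha)}
    - \sum_{i \in B_{\alpha}} \exp\!\left(\frac{\alpha_i + \beta_j - M_{i,j}}{\text{reg}}\right) a_i \, b_j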
@@ -516,8 +521,8 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
     '''
 
     G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
-           M[batch_alpha, :][:, batch_beta]) / reg) * a[batch_alpha, None] *
-           b[None, batch_beta])
+                   M[batch_alpha, :][:, batch_beta]) / reg) *
+           a[batch_alpha, None] * b[None, batch_beta])
     grad_beta = np.zeros(np.shape(M)[1])
     grad_alpha = np.zeros(np.shape(M)[0])
     grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0] +
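A minimal NumPy sketch of the batched dual gradient these two hunks document and compute (a standalone toy mirroring the expressions above, not the library function; the grad_alpha branch is inferred by symmetry and is an assumption):

import numpy as np

def toy_batch_grad_dual(a, b, M, reg, alpha, beta, batch_alpha, batch_beta):
    ns, nt = M.shape
    # Block of the entropic plan restricted to the sampled indices:
    # H_ij = exp((alpha_i + beta_j - M_ij) / reg) * a_i * b_j
    H = np.exp((alpha[batch_alpha, None] + beta[None, batch_beta]
                - M[batch_alpha, :][:, batch_beta]) / reg)
    H *= a[batch_alpha, None] * b[None, batch_beta]
    grad_alpha, grad_beta = np.zeros(ns), np.zeros(nt)
    # Sampled coordinates of the gradient: marginal weight term minus the
    # sum over the plan block (cf. the corrected docstring formula).
    grad_alpha[batch_alpha] = a[batch_alpha] * len(batch_beta) / nt - H.sum(1)
    grad_beta[batch_beta] = b[batch_beta] * len(batch_alpha) / ns - H.sum(0)
    return grad_alpha, grad_beta

# Toy usage with uniform weights:
rng = np.random.RandomState(0)
ns, nt, bs = 6, 5, 3
a, b = np.full(ns, 1 / ns), np.full(nt, 1 / nt)
ga, gb = toy_batch_grad_dual(a, b, rng.rand(ns, nt), 1.0,
                             np.zeros(ns), np.zeros(nt),
                             rng.choice(ns, bs, replace=False),
                             rng.choice(nt, bs, replace=False))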
@@ -548,23 +553,20 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
 
     Parameters
     ----------
-
+    a : np.ndarray(ns,),
+        source measure
+    b : np.ndarray(nt,),
+        target measure
     M : np.ndarray(ns, nt),
         cost matrix
     reg : float number,
         Regularization term > 0
-    alpha : np.ndarray(ns,)
-        dual variable
-    beta : np.ndarray(nt,)
-        dual variable
     batch_size : int number
         size of the batch
     numItermax : int number
         number of iteration
     lr : float number
         learning rate
-    alternate : bool, optional
-        alternating algorithm
 
     Returns
     -------
@@ -591,8 +593,8 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
     >>> Y_target = rng.randn(n_target, 2)
     >>> M = ot.dist(X_source, Y_target)
     >>> sgd_dual_pi, log = stochastic.solve_dual_entropic(a, b, M, reg,
-                                                          batch_size,
-                                                          numItermax, lr, log)
+                                                           batch_size,
+                                                           numItermax, lr, log)
     >>> print(log['alpha'], log['beta'])
     >>> print(sgd_dual_pi)
 
@@ -609,7 +611,7 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
     cur_alpha = np.zeros(n_source)
     cur_beta = np.zeros(n_target)
     for cur_iter in range(numItermax):
-        k = np.sqrt(cur_iter / 100 + 1)
+        k = np.sqrt(cur_iter + 1)
         batch_alpha = np.random.choice(n_source, batch_size, replace=False)
         batch_beta = np.random.choice(n_target, batch_size, replace=False)
         update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
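The one functional change here is the step-size schedule: k = sqrt(cur_iter + 1) gives the update a standard O(1/sqrt(t)) decay, whereas the old k = sqrt(cur_iter / 100 + 1) kept the effective step roughly ten times larger for large t. A self-contained toy on a hypothetical quadratic objective, not the POT dual, assuming the update is scaled by lr / k as the loop suggests:

import numpy as np

rng = np.random.RandomState(0)
x = np.zeros(2)
target = np.array([1.0, -2.0])
lr = 0.5
for cur_iter in range(1000):
    k = np.sqrt(cur_iter + 1)                  # new schedule from this commit
    grad = (x - target) + 0.1 * rng.randn(2)   # noisy gradient
    x -= (lr / k) * grad                       # O(1/sqrt(t)) step decay
print(x)  # ends up close to target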

test/test_stochastic.py

Lines changed: 12 additions & 15 deletions
@@ -97,7 +97,6 @@ def test_sag_asgd_sinkhorn():
 
     x = rng.randn(n, 2)
     u = ot.utils.unif(n)
-    zero = np.zeros(n)
     M = ot.dist(x, x)
 
     G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
@@ -108,13 +107,13 @@ def test_sag_asgd_sinkhorn():
 
     # check constratints
     np.testing.assert_allclose(
-        zero, (G_sag - G_sinkhorn).sum(1), atol=1e-03)  # cf convergence sag
+        G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
     np.testing.assert_allclose(
-        zero, (G_sag - G_sinkhorn).sum(0), atol=1e-03)  # cf convergence sag
+        G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
     np.testing.assert_allclose(
-        zero, (G_asgd - G_sinkhorn).sum(1), atol=1e-03)  # cf convergence asgd
+        G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
     np.testing.assert_allclose(
-        zero, (G_asgd - G_sinkhorn).sum(0), atol=1e-03)  # cf convergence asgd
+        G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
     np.testing.assert_allclose(
         G_sag, G_sinkhorn, atol=1e-03)  # cf convergence sag
     np.testing.assert_allclose(
@@ -151,9 +150,9 @@ def test_stochastic_dual_sgd():
 
     # check constratints
     np.testing.assert_allclose(
-        u, G.sum(1), atol=1e-04)  # cf convergence sgd
+        u, G.sum(1), atol=1e-03)  # cf convergence sgd
     np.testing.assert_allclose(
-        u, G.sum(0), atol=1e-04)  # cf convergence sgd
+        u, G.sum(0), atol=1e-03)  # cf convergence sgd
 
 
 #############################################################################
@@ -175,7 +174,6 @@ def test_dual_sgd_sinkhorn():
     # Test uniform
     x = rng.randn(n, 2)
     u = ot.utils.unif(n)
-    zero = np.zeros(n)
     M = ot.dist(x, x)
 
     G_sgd = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
@@ -185,17 +183,16 @@ def test_dual_sgd_sinkhorn():
 
     # check constratints
     np.testing.assert_allclose(
-        zero, (G_sgd - G_sinkhorn).sum(1), atol=1e-04)  # cf convergence sgd
+        G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
     np.testing.assert_allclose(
-        zero, (G_sgd - G_sinkhorn).sum(0), atol=1e-04)  # cf convergence sgd
+        G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
     np.testing.assert_allclose(
-        G_sgd, G_sinkhorn, atol=1e-04)  # cf convergence sgd
+        G_sgd, G_sinkhorn, atol=1e-03)  # cf convergence sgd
 
     # Test gaussian
     n = 30
     reg = 1
     batch_size = 30
-    zero = np.zeros(n)
 
     a = ot.datasets.make_1D_gauss(n, 15, 5)  # m= mean, s= std
     b = ot.datasets.make_1D_gauss(n, 15, 5)
@@ -211,8 +208,8 @@ def test_dual_sgd_sinkhorn():
 
     # check constratints
     np.testing.assert_allclose(
-        zero, (G_sgd - G_sinkhorn).sum(1), atol=1e-04)  # cf convergence sgd
+        G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
     np.testing.assert_allclose(
-        zero, (G_sgd - G_sinkhorn).sum(0), atol=1e-04)  # cf convergence sgd
+        G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
     np.testing.assert_allclose(
-        G_sgd, G_sinkhorn, atol=1e-04)  # cf convergence sgd
+        G_sgd, G_sinkhorn, atol=1e-03)  # cf convergence sgd
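The rewritten assertions compare the two plans' marginals directly instead of checking their difference against a pre-allocated zero vector (which is why the zero = np.zeros(n) lines above are deleted), and relax the tolerance from 1e-04 to 1e-03. A standalone sketch of the before/after pattern with hypothetical toy plans (in the tests, these come from the stochastic solver and from Sinkhorn on the same problem):

import numpy as np

rng = np.random.RandomState(0)
G_solver = rng.rand(4, 4)
G_solver /= G_solver.sum()                 # a toy transport plan
G_ref = G_solver + 1e-5 * rng.randn(4, 4)  # a near-identical reference plan

# Old style: marginal difference checked against an explicit zero vector.
np.testing.assert_allclose(
    np.zeros(4), (G_solver - G_ref).sum(1), atol=1e-03)

# New style: compare row/column marginals of the two plans directly.
np.testing.assert_allclose(G_solver.sum(1), G_ref.sum(1), atol=1e-03)
np.testing.assert_allclose(G_solver.sum(0), G_ref.sum(0), atol=1e-03)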
