Skip to content

Commit b6fb148

Browse files
committed
more
1 parent 1b00740 commit b6fb148

File tree

3 files changed

+111
-138
lines changed

3 files changed

+111
-138
lines changed

ot/stochastic.py

Lines changed: 90 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -38,22 +38,20 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
3838
3939
Parameters
4040
----------
41-
42-
b : np.ndarray(nt,)
43-
target measure
44-
M : np.ndarray(ns, nt)
45-
cost matrix
46-
reg : float nu
47-
Regularization term > 0
48-
v : np.ndarray(nt,)
49-
dual variable
50-
i : number int
51-
picked number i
41+
b : ndarray, shape (nt,)
42+
Target measure.
43+
M : ndarray, shape (ns, nt)
44+
Cost matrix.
45+
reg : float
46+
Regularization term > 0.
47+
v : ndarray, shape (nt,)
48+
Dual variable.
49+
i : int
50+
Picked number i.
5251
5352
Returns
5453
-------
55-
56-
coordinate gradient : np.ndarray(nt,)
54+
coordinate gradient : ndarray, shape (nt,)
5755
5856
Examples
5957
--------
@@ -78,14 +76,11 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
7876
7977
References
8078
----------
81-
8279
[Genevay et al., 2016] :
83-
Stochastic Optimization for Large-scale Optimal Transport,
84-
Advances in Neural Information Processing Systems (2016),
85-
arXiv preprint arxiv:1605.08527.
86-
80+
Stochastic Optimization for Large-scale Optimal Transport,
81+
Advances in Neural Information Processing Systems (2016),
82+
arXiv preprint arxiv:1605.08527.
8783
'''
88-
8984
r = M[i, :] - beta
9085
exp_beta = np.exp(-r / reg) * b
9186
khi = exp_beta / (np.sum(exp_beta))
@@ -121,24 +116,23 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
121116
Parameters
122117
----------
123118
124-
a : np.ndarray(ns,),
125-
source measure
126-
b : np.ndarray(nt,),
127-
target measure
128-
M : np.ndarray(ns, nt),
129-
cost matrix
130-
reg : float number,
119+
a : ndarray, shape (ns,),
120+
Source measure.
121+
b : ndarray, shape (nt,),
122+
Target measure.
123+
M : ndarray, shape (ns, nt),
124+
Cost matrix.
125+
reg : float
131126
Regularization term > 0
132-
numItermax : int number
133-
number of iteration
134-
lr : float number
135-
learning rate
127+
numItermax : int
128+
Number of iteration.
129+
lr : float
130+
Learning rate.
136131
137132
Returns
138133
-------
139-
140-
v : np.ndarray(nt,)
141-
dual variable
134+
v : ndarray, shape (nt,)
135+
Dual variable.
142136
143137
Examples
144138
--------
@@ -213,23 +207,20 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
213207
214208
Parameters
215209
----------
216-
217-
b : np.ndarray(nt,)
210+
b : ndarray, shape (nt,)
218211
target measure
219-
M : np.ndarray(ns, nt)
212+
M : ndarray, shape (ns, nt)
220213
cost matrix
221-
reg : float number
214+
reg : float
222215
Regularization term > 0
223-
numItermax : int number
224-
number of iteration
225-
lr : float number
226-
learning rate
227-
216+
numItermax : int
217+
Number of iteration.
218+
lr : float
219+
Learning rate.
228220
229221
Returns
230222
-------
231-
232-
ave_v : np.ndarray(nt,)
223+
ave_v : ndarray, shape (nt,)
233224
dual variable
234225
235226
Examples
@@ -256,9 +247,9 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
256247
----------
257248
258249
[Genevay et al., 2016] :
259-
Stochastic Optimization for Large-scale Optimal Transport,
260-
Advances in Neural Information Processing Systems (2016),
261-
arXiv preprint arxiv:1605.08527.
250+
Stochastic Optimization for Large-scale Optimal Transport,
251+
Advances in Neural Information Processing Systems (2016),
252+
arXiv preprint arxiv:1605.08527.
262253
'''
263254

264255
if lr is None:
@@ -298,21 +289,19 @@ def c_transform_entropic(b, M, reg, beta):
298289
299290
Parameters
300291
----------
301-
302-
b : np.ndarray(nt,)
303-
target measure
304-
M : np.ndarray(ns, nt)
305-
cost matrix
292+
b : ndarray, shape (nt,)
293+
Target measure
294+
M : ndarray, shape (ns, nt)
295+
Cost matrix
306296
reg : float
307-
regularization term > 0
308-
v : np.ndarray(nt,)
309-
dual variable
297+
Regularization term > 0
298+
v : ndarray, shape (nt,)
299+
Dual variable.
310300
311301
Returns
312302
-------
313-
314-
u : np.ndarray(ns,)
315-
dual variable
303+
u : ndarray, shape (ns,)
304+
Dual variable.
316305
317306
Examples
318307
--------
@@ -338,9 +327,9 @@ def c_transform_entropic(b, M, reg, beta):
338327
----------
339328
340329
[Genevay et al., 2016] :
341-
Stochastic Optimization for Large-scale Optimal Transport,
342-
Advances in Neural Information Processing Systems (2016),
343-
arXiv preprint arxiv:1605.08527.
330+
Stochastic Optimization for Large-scale Optimal Transport,
331+
Advances in Neural Information Processing Systems (2016),
332+
arXiv preprint arxiv:1605.08527.
344333
'''
345334

346335
n_source = np.shape(M)[0]
@@ -382,31 +371,30 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
382371
Parameters
383372
----------
384373
385-
a : np.ndarray(ns,)
374+
a : ndarray, shape (ns,)
386375
source measure
387-
b : np.ndarray(nt,)
376+
b : ndarray, shape (nt,)
388377
target measure
389-
M : np.ndarray(ns, nt)
378+
M : ndarray, shape (ns, nt)
390379
cost matrix
391-
reg : float number
380+
reg : float
392381
Regularization term > 0
393382
methode : str
394383
used method (SAG or ASGD)
395-
numItermax : int number
384+
numItermax : int
396385
number of iteration
397-
lr : float number
386+
lr : float
398387
learning rate
399-
n_source : int number
388+
n_source : int
400389
size of the source measure
401-
n_target : int number
390+
n_target : int
402391
size of the target measure
403392
log : bool, optional
404393
record log if True
405394
406395
Returns
407396
-------
408-
409-
pi : np.ndarray(ns, nt)
397+
pi : ndarray, shape (ns, nt)
410398
transportation matrix
411399
log : dict
412400
log dictionary return only if log==True in parameters
@@ -495,30 +483,28 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
495483
496484
Parameters
497485
----------
498-
499-
a : np.ndarray(ns,)
486+
a : ndarray, shape (ns,)
500487
source measure
501-
b : np.ndarray(nt,)
488+
b : ndarray, shape (nt,)
502489
target measure
503-
M : np.ndarray(ns, nt)
490+
M : ndarray, shape (ns, nt)
504491
cost matrix
505-
reg : float number
492+
reg : float
506493
Regularization term > 0
507-
alpha : np.ndarray(ns,)
494+
alpha : ndarray, shape (ns,)
508495
dual variable
509-
beta : np.ndarray(nt,)
496+
beta : ndarray, shape (nt,)
510497
dual variable
511-
batch_size : int number
498+
batch_size : int
512499
size of the batch
513-
batch_alpha : np.ndarray(bs,)
500+
batch_alpha : ndarray, shape (bs,)
514501
batch of index of alpha
515-
batch_beta : np.ndarray(bs,)
502+
batch_beta : ndarray, shape (bs,)
516503
batch of index of beta
517504
518505
Returns
519506
-------
520-
521-
grad : np.ndarray(ns,)
507+
grad : ndarray, shape (ns,)
522508
partial grad F
523509
524510
Examples
@@ -591,28 +577,26 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
591577
592578
Parameters
593579
----------
594-
595-
a : np.ndarray(ns,)
580+
a : ndarray, shape (ns,)
596581
source measure
597-
b : np.ndarray(nt,)
582+
b : ndarray, shape (nt,)
598583
target measure
599-
M : np.ndarray(ns, nt)
584+
M : ndarray, shape (ns, nt)
600585
cost matrix
601-
reg : float number
586+
reg : float
602587
Regularization term > 0
603-
batch_size : int number
588+
batch_size : int
604589
size of the batch
605-
numItermax : int number
590+
numItermax : int
606591
number of iteration
607-
lr : float number
592+
lr : float
608593
learning rate
609594
610595
Returns
611596
-------
612-
613-
alpha : np.ndarray(ns,)
597+
alpha : ndarray, shape (ns,)
614598
dual variable
615-
beta : np.ndarray(nt,)
599+
beta : ndarray, shape (nt,)
616600
dual variable
617601
618602
Examples
@@ -648,10 +632,9 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
648632
649633
References
650634
----------
651-
652635
[Seguy et al., 2018] :
653-
International Conference on Learning Representation (2018),
654-
arXiv preprint arxiv:1711.02283.
636+
International Conference on Learning Representation (2018),
637+
arXiv preprint arxiv:1711.02283.
655638
'''
656639

657640
n_source = np.shape(M)[0]
@@ -696,28 +679,26 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
696679
697680
Parameters
698681
----------
699-
700-
a : np.ndarray(ns,)
682+
a : ndarray, shape (ns,)
701683
source measure
702-
b : np.ndarray(nt,)
684+
b : ndarray, shape (nt,)
703685
target measure
704-
M : np.ndarray(ns, nt)
686+
M : ndarray, shape (ns, nt)
705687
cost matrix
706-
reg : float number
688+
reg : float
707689
Regularization term > 0
708-
batch_size : int number
690+
batch_size : int
709691
size of the batch
710-
numItermax : int number
692+
numItermax : int
711693
number of iteration
712-
lr : float number
694+
lr : float
713695
learning rate
714696
log : bool, optional
715697
record log if True
716698
717699
Returns
718700
-------
719-
720-
pi : np.ndarray(ns, nt)
701+
pi : ndarray, shape (ns, nt)
721702
transportation matrix
722703
log : dict
723704
log dictionary return only if log==True in parameters
@@ -757,8 +738,8 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
757738
----------
758739
759740
[Seguy et al., 2018] :
760-
International Conference on Learning Representation (2018),
761-
arXiv preprint arxiv:1711.02283.
741+
International Conference on Learning Representation (2018),
742+
arXiv preprint arxiv:1711.02283.
762743
'''
763744

764745
opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,

ot/unbalanced.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
380380
print(
381381
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
382382
print('{:5d}|{:8e}|'.format(cpt, err))
383-
cpt = cpt + 1
383+
cpt += 1
384+
384385
if log:
385386
log['u'] = u
386387
log['v'] = v

0 commit comments

Comments (0)