@@ -52,19 +52,23 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
52
52
Examples
53
53
--------
54
54
>>> import ot
55
+ >>> np.random.seed(0)
55
56
>>> n_source = 7
56
57
>>> n_target = 4
57
- >>> reg = 1
58
- >>> numItermax = 300000
59
58
>>> a = ot.utils.unif(n_source)
60
59
>>> b = ot.utils.unif(n_target)
61
- >>> rng = np.random.RandomState(0)
62
- >>> X_source = rng.randn(n_source, 2)
63
- >>> Y_target = rng.randn(n_target, 2)
60
+ >>> X_source = np.random.randn(n_source, 2)
61
+ >>> Y_target = np.random.randn(n_target, 2)
64
62
>>> M = ot.dist(X_source, Y_target)
65
- >>> method = "ASGD"
66
- >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
67
- >>> print(asgd_pi)
63
+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
64
+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
65
+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
66
+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
67
+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
68
+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
69
+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
70
+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
71
+
68
72
69
73
References
70
74
----------
@@ -133,19 +137,22 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
133
137
Examples
134
138
--------
135
139
>>> import ot
140
+ >>> np.random.seed(0)
136
141
>>> n_source = 7
137
142
>>> n_target = 4
138
- >>> reg = 1
139
- >>> numItermax = 300000
140
143
>>> a = ot.utils.unif(n_source)
141
144
>>> b = ot.utils.unif(n_target)
142
- >>> rng = np.random.RandomState(0)
143
- >>> X_source = rng.randn(n_source, 2)
144
- >>> Y_target = rng.randn(n_target, 2)
145
+ >>> X_source = np.random.randn(n_source, 2)
146
+ >>> Y_target = np.random.randn(n_target, 2)
145
147
>>> M = ot.dist(X_source, Y_target)
146
- >>> method = "ASGD"
147
- >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
148
- >>> print(asgd_pi)
148
+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
149
+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
150
+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
151
+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
152
+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
153
+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
154
+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
155
+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
149
156
150
157
References
151
158
----------
@@ -222,19 +229,22 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
222
229
Examples
223
230
--------
224
231
>>> import ot
232
+ >>> np.random.seed(0)
225
233
>>> n_source = 7
226
234
>>> n_target = 4
227
- >>> reg = 1
228
- >>> numItermax = 300000
229
235
>>> a = ot.utils.unif(n_source)
230
236
>>> b = ot.utils.unif(n_target)
231
- >>> rng = np.random.RandomState(0)
232
- >>> X_source = rng.randn(n_source, 2)
233
- >>> Y_target = rng.randn(n_target, 2)
237
+ >>> X_source = np.random.randn(n_source, 2)
238
+ >>> Y_target = np.random.randn(n_target, 2)
234
239
>>> M = ot.dist(X_source, Y_target)
235
- >>> method = "ASGD"
236
- >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
237
- >>> print(asgd_pi)
240
+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
241
+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
242
+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
243
+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
244
+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
245
+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
246
+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
247
+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
238
248
239
249
References
240
250
----------
@@ -301,19 +311,22 @@ def c_transform_entropic(b, M, reg, beta):
301
311
Examples
302
312
--------
303
313
>>> import ot
314
+ >>> np.random.seed(0)
304
315
>>> n_source = 7
305
316
>>> n_target = 4
306
- >>> reg = 1
307
- >>> numItermax = 300000
308
317
>>> a = ot.utils.unif(n_source)
309
318
>>> b = ot.utils.unif(n_target)
310
- >>> rng = np.random.RandomState(0)
311
- >>> X_source = rng.randn(n_source, 2)
312
- >>> Y_target = rng.randn(n_target, 2)
319
+ >>> X_source = np.random.randn(n_source, 2)
320
+ >>> Y_target = np.random.randn(n_target, 2)
313
321
>>> M = ot.dist(X_source, Y_target)
314
- >>> method = "ASGD"
315
- >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
316
- >>> print(asgd_pi)
322
+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
323
+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
324
+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
325
+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
326
+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
327
+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
328
+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
329
+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
317
330
318
331
References
319
332
----------
@@ -395,19 +408,22 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
395
408
Examples
396
409
--------
397
410
>>> import ot
411
+ >>> np.random.seed(0)
398
412
>>> n_source = 7
399
413
>>> n_target = 4
400
- >>> reg = 1
401
- >>> numItermax = 300000
402
414
>>> a = ot.utils.unif(n_source)
403
415
>>> b = ot.utils.unif(n_target)
404
- >>> rng = np.random.RandomState(0)
405
- >>> X_source = rng.randn(n_source, 2)
406
- >>> Y_target = rng.randn(n_target, 2)
416
+ >>> X_source = np.random.randn(n_source, 2)
417
+ >>> Y_target = np.random.randn(n_target, 2)
407
418
>>> M = ot.dist(X_source, Y_target)
408
- >>> method = "ASGD"
409
- >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
410
- >>> print(asgd_pi)
419
+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
420
+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
421
+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
422
+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
423
+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
424
+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
425
+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
426
+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
411
427
412
428
References
413
429
----------
@@ -502,22 +518,28 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
502
518
Examples
503
519
--------
504
520
>>> import ot
521
+ >>> np.random.seed(0)
505
522
>>> n_source = 7
506
523
>>> n_target = 4
507
- >>> reg = 1
508
- >>> numItermax = 20000
509
- >>> lr = 0.1
510
- >>> batch_size = 3
511
- >>> log = True
512
524
>>> a = ot.utils.unif(n_source)
513
525
>>> b = ot.utils.unif(n_target)
514
- >>> rng = np.random.RandomState(0)
515
- >>> X_source = rng.randn(n_source, 2)
516
- >>> Y_target = rng.randn(n_target, 2)
526
+ >>> X_source = np.random.randn(n_source, 2)
527
+ >>> Y_target = np.random.randn(n_target, 2)
517
528
>>> M = ot.dist(X_source, Y_target)
518
- >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
519
- >>> print(log['alpha'], log['beta'])
520
- >>> print(sgd_dual_pi)
529
+ >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg=1, batch_size=3, numItermax=30000, lr=0.1, log=True)
530
+ >>> log['alpha']
531
+ array([0.71759102, 1.57057384, 0.85576566, 0.1208211 , 0.59190466,
532
+ 1.197148 , 0.17805133])
533
+ >>> log['beta']
534
+ array([0.49741367, 0.57478564, 1.40075528, 2.75890102])
535
+ >>> sgd_dual_pi
536
+ array([[2.09730063e-02, 8.38169324e-02, 7.50365455e-03, 8.72731415e-09],
537
+ [5.58432437e-03, 5.89881299e-04, 3.09558411e-05, 8.35469849e-07],
538
+ [3.26489515e-03, 7.15536035e-02, 2.99778211e-02, 3.02601593e-10],
539
+ [4.05390622e-02, 5.31085068e-02, 6.65191787e-02, 1.55812785e-06],
540
+ [7.82299812e-02, 6.12099102e-03, 4.44989098e-02, 2.37719187e-03],
541
+ [5.06266486e-02, 2.16230494e-03, 2.26215141e-03, 6.81514609e-04],
542
+ [6.06713990e-02, 3.98139808e-02, 5.46829338e-02, 8.62371424e-06]])
521
543
522
544
References
523
545
----------
@@ -526,7 +548,6 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
526
548
International Conference on Learning Representation (2018),
527
549
arXiv preprint arxiv:1711.02283.
528
550
'''
529
-
530
551
G = - (np .exp ((alpha [batch_alpha , None ] + beta [None , batch_beta ] -
531
552
M [batch_alpha , :][:, batch_beta ]) / reg ) *
532
553
a [batch_alpha , None ] * b [None , batch_beta ])
@@ -605,8 +626,19 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
605
626
>>> Y_target = rng.randn(n_target, 2)
606
627
>>> M = ot.dist(X_source, Y_target)
607
628
>>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
608
- >>> print(log['alpha'], log['beta'])
609
- >>> print(sgd_dual_pi)
629
+ >>> log['alpha']
630
+ array([0.64171798, 1.27932201, 0.78132257, 0.15638935, 0.54888354,
631
+ 1.03663469, 0.20595781])
632
+ >>> log['beta']
633
+ array([0.51207194, 0.58033189, 1.28922676, 2.26859736])
634
+ >>> sgd_dual_pi
635
+ array([[1.97276541e-02, 7.81248547e-02, 6.22136048e-03, 4.95442423e-09],
636
+ [4.23494310e-03, 4.43286263e-04, 2.06927079e-05, 3.82389139e-07],
637
+ [3.07542414e-03, 6.67897769e-02, 2.48904999e-02, 1.72030247e-10],
638
+ [4.26271990e-02, 5.53375455e-02, 6.16535024e-02, 9.88812650e-07],
639
+ [7.60423265e-02, 5.89585256e-03, 3.81267087e-02, 1.39458256e-03],
640
+ [4.37557504e-02, 1.85189176e-03, 1.72335760e-03, 3.55491279e-04],
641
+ [6.33096109e-02, 4.11683954e-02, 5.02962051e-02, 5.43097516e-06]])
610
642
611
643
References
612
644
----------
@@ -701,8 +733,19 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
701
733
>>> Y_target = rng.randn(n_target, 2)
702
734
>>> M = ot.dist(X_source, Y_target)
703
735
>>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
704
- >>> print(log['alpha'], log['beta'])
705
- >>> print(sgd_dual_pi)
736
+ >>> log['alpha']
737
+ array([0.64057733, 1.2683513 , 0.75610161, 0.16024284, 0.54926534,
738
+ 1.0514201 , 0.19958936])
739
+ >>> log['beta']
740
+ array([0.51372571, 0.58843489, 1.27993921, 2.24344807])
741
+ >>> sgd_dual_pi
742
+ array([[1.97377795e-02, 7.86706853e-02, 6.15682001e-03, 4.82586997e-09],
743
+ [4.19566963e-03, 4.42016865e-04, 2.02777272e-05, 3.68823708e-07],
744
+ [3.00379244e-03, 6.56562018e-02, 2.40462171e-02, 1.63579656e-10],
745
+ [4.28626062e-02, 5.60031599e-02, 6.13193826e-02, 9.67977735e-07],
746
+ [7.61972739e-02, 5.94609051e-03, 3.77886693e-02, 1.36046648e-03],
747
+ [4.44810042e-02, 1.89476742e-03, 1.73285847e-03, 3.51826036e-04],
748
+ [6.30118293e-02, 4.12398660e-02, 4.95148998e-02, 5.26247246e-06]])
706
749
707
750
References
708
751
----------
0 commit comments