Skip to content

Commit a08375c

Browse files
committed
Fixed all doctests assuming functions are working properly (actually tested in tests/)
1 parent 64dba52 commit a08375c

File tree

4 files changed

+104
-60
lines changed

4 files changed

+104
-60
lines changed

.travis.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ before_script: # configure a headless display to test plot generation
2626
# command to install dependencies
2727
install:
2828
- pip install -r requirements.txt
29-
- pip install numpy>=1.14 # for numpy array formatting in doctests
30-
- pip install "scipy<1.3" # otherwise, pymanopt fails, cf <https://github.com/pymanopt/pymanopt/issues/77>
29+
- pip install numpy>=1.14 "scipy<1.3" # for numpy array formatting in doctests
30+
# ^ scipy version: otherwise, pymanopt fails, cf <https://github.com/pymanopt/pymanopt/issues/77>
31+
- python -c "import numpy; import scipy; print('numpy: ', numpy.__version__); print('scipy: ', scipy.__version__)"
3132
- pip install flake8 pytest "pytest-cov<2.6"
3233
- pip install .
3334
# command to run tests + check syntax style

ot/bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1559,7 +1559,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15591559
>>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
15601560
>>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
15611561
>>> empirical_sinkhorn_divergence(X_s, X_t, reg)
1562-
array([2.99977435])
1562+
array([1.49988718])
15631563
15641564
15651565
References

ot/stochastic.py

Lines changed: 99 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,23 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
5252
Examples
5353
--------
5454
>>> import ot
55+
>>> np.random.seed(0)
5556
>>> n_source = 7
5657
>>> n_target = 4
57-
>>> reg = 1
58-
>>> numItermax = 300000
5958
>>> a = ot.utils.unif(n_source)
6059
>>> b = ot.utils.unif(n_target)
61-
>>> rng = np.random.RandomState(0)
62-
>>> X_source = rng.randn(n_source, 2)
63-
>>> Y_target = rng.randn(n_target, 2)
60+
>>> X_source = np.random.randn(n_source, 2)
61+
>>> Y_target = np.random.randn(n_target, 2)
6462
>>> M = ot.dist(X_source, Y_target)
65-
>>> method = "ASGD"
66-
>>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
67-
>>> print(asgd_pi)
63+
>>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
64+
array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
65+
[1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
66+
[3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
67+
[2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
68+
[9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
69+
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
70+
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
71+
6872
6973
References
7074
----------
@@ -133,19 +137,22 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
133137
Examples
134138
--------
135139
>>> import ot
140+
>>> np.random.seed(0)
136141
>>> n_source = 7
137142
>>> n_target = 4
138-
>>> reg = 1
139-
>>> numItermax = 300000
140143
>>> a = ot.utils.unif(n_source)
141144
>>> b = ot.utils.unif(n_target)
142-
>>> rng = np.random.RandomState(0)
143-
>>> X_source = rng.randn(n_source, 2)
144-
>>> Y_target = rng.randn(n_target, 2)
145+
>>> X_source = np.random.randn(n_source, 2)
146+
>>> Y_target = np.random.randn(n_target, 2)
145147
>>> M = ot.dist(X_source, Y_target)
146-
>>> method = "ASGD"
147-
>>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
148-
>>> print(asgd_pi)
148+
>>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
149+
array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
150+
[1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
151+
[3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
152+
[2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
153+
[9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
154+
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
155+
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
149156
150157
References
151158
----------
@@ -222,19 +229,22 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
222229
Examples
223230
--------
224231
>>> import ot
232+
>>> np.random.seed(0)
225233
>>> n_source = 7
226234
>>> n_target = 4
227-
>>> reg = 1
228-
>>> numItermax = 300000
229235
>>> a = ot.utils.unif(n_source)
230236
>>> b = ot.utils.unif(n_target)
231-
>>> rng = np.random.RandomState(0)
232-
>>> X_source = rng.randn(n_source, 2)
233-
>>> Y_target = rng.randn(n_target, 2)
237+
>>> X_source = np.random.randn(n_source, 2)
238+
>>> Y_target = np.random.randn(n_target, 2)
234239
>>> M = ot.dist(X_source, Y_target)
235-
>>> method = "ASGD"
236-
>>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
237-
>>> print(asgd_pi)
240+
>>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
241+
array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
242+
[1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
243+
[3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
244+
[2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
245+
[9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
246+
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
247+
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
238248
239249
References
240250
----------
@@ -301,19 +311,22 @@ def c_transform_entropic(b, M, reg, beta):
301311
Examples
302312
--------
303313
>>> import ot
314+
>>> np.random.seed(0)
304315
>>> n_source = 7
305316
>>> n_target = 4
306-
>>> reg = 1
307-
>>> numItermax = 300000
308317
>>> a = ot.utils.unif(n_source)
309318
>>> b = ot.utils.unif(n_target)
310-
>>> rng = np.random.RandomState(0)
311-
>>> X_source = rng.randn(n_source, 2)
312-
>>> Y_target = rng.randn(n_target, 2)
319+
>>> X_source = np.random.randn(n_source, 2)
320+
>>> Y_target = np.random.randn(n_target, 2)
313321
>>> M = ot.dist(X_source, Y_target)
314-
>>> method = "ASGD"
315-
>>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
316-
>>> print(asgd_pi)
322+
>>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
323+
array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
324+
[1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
325+
[3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
326+
[2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
327+
[9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
328+
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
329+
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
317330
318331
References
319332
----------
@@ -395,19 +408,22 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
395408
Examples
396409
--------
397410
>>> import ot
411+
>>> np.random.seed(0)
398412
>>> n_source = 7
399413
>>> n_target = 4
400-
>>> reg = 1
401-
>>> numItermax = 300000
402414
>>> a = ot.utils.unif(n_source)
403415
>>> b = ot.utils.unif(n_target)
404-
>>> rng = np.random.RandomState(0)
405-
>>> X_source = rng.randn(n_source, 2)
406-
>>> Y_target = rng.randn(n_target, 2)
416+
>>> X_source = np.random.randn(n_source, 2)
417+
>>> Y_target = np.random.randn(n_target, 2)
407418
>>> M = ot.dist(X_source, Y_target)
408-
>>> method = "ASGD"
409-
>>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
410-
>>> print(asgd_pi)
419+
>>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
420+
array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
421+
[1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
422+
[3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
423+
[2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
424+
[9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
425+
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
426+
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
411427
412428
References
413429
----------
@@ -502,22 +518,28 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
502518
Examples
503519
--------
504520
>>> import ot
521+
>>> np.random.seed(0)
505522
>>> n_source = 7
506523
>>> n_target = 4
507-
>>> reg = 1
508-
>>> numItermax = 20000
509-
>>> lr = 0.1
510-
>>> batch_size = 3
511-
>>> log = True
512524
>>> a = ot.utils.unif(n_source)
513525
>>> b = ot.utils.unif(n_target)
514-
>>> rng = np.random.RandomState(0)
515-
>>> X_source = rng.randn(n_source, 2)
516-
>>> Y_target = rng.randn(n_target, 2)
526+
>>> X_source = np.random.randn(n_source, 2)
527+
>>> Y_target = np.random.randn(n_target, 2)
517528
>>> M = ot.dist(X_source, Y_target)
518-
>>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
519-
>>> print(log['alpha'], log['beta'])
520-
>>> print(sgd_dual_pi)
529+
>>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg=1, batch_size=3, numItermax=30000, lr=0.1, log=True)
530+
>>> log['alpha']
531+
array([0.71759102, 1.57057384, 0.85576566, 0.1208211 , 0.59190466,
532+
1.197148 , 0.17805133])
533+
>>> log['beta']
534+
array([0.49741367, 0.57478564, 1.40075528, 2.75890102])
535+
>>> sgd_dual_pi
536+
array([[2.09730063e-02, 8.38169324e-02, 7.50365455e-03, 8.72731415e-09],
537+
[5.58432437e-03, 5.89881299e-04, 3.09558411e-05, 8.35469849e-07],
538+
[3.26489515e-03, 7.15536035e-02, 2.99778211e-02, 3.02601593e-10],
539+
[4.05390622e-02, 5.31085068e-02, 6.65191787e-02, 1.55812785e-06],
540+
[7.82299812e-02, 6.12099102e-03, 4.44989098e-02, 2.37719187e-03],
541+
[5.06266486e-02, 2.16230494e-03, 2.26215141e-03, 6.81514609e-04],
542+
[6.06713990e-02, 3.98139808e-02, 5.46829338e-02, 8.62371424e-06]])
521543
522544
References
523545
----------
@@ -526,7 +548,6 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
526548
International Conference on Learning Representation (2018),
527549
arXiv preprint arxiv:1711.02283.
528550
'''
529-
530551
G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
531552
M[batch_alpha, :][:, batch_beta]) / reg) *
532553
a[batch_alpha, None] * b[None, batch_beta])
@@ -605,8 +626,19 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
605626
>>> Y_target = rng.randn(n_target, 2)
606627
>>> M = ot.dist(X_source, Y_target)
607628
>>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
608-
>>> print(log['alpha'], log['beta'])
609-
>>> print(sgd_dual_pi)
629+
>>> log['alpha']
630+
array([0.64171798, 1.27932201, 0.78132257, 0.15638935, 0.54888354,
631+
1.03663469, 0.20595781])
632+
>>> log['beta']
633+
array([0.51207194, 0.58033189, 1.28922676, 2.26859736])
634+
>>> sgd_dual_pi
635+
array([[1.97276541e-02, 7.81248547e-02, 6.22136048e-03, 4.95442423e-09],
636+
[4.23494310e-03, 4.43286263e-04, 2.06927079e-05, 3.82389139e-07],
637+
[3.07542414e-03, 6.67897769e-02, 2.48904999e-02, 1.72030247e-10],
638+
[4.26271990e-02, 5.53375455e-02, 6.16535024e-02, 9.88812650e-07],
639+
[7.60423265e-02, 5.89585256e-03, 3.81267087e-02, 1.39458256e-03],
640+
[4.37557504e-02, 1.85189176e-03, 1.72335760e-03, 3.55491279e-04],
641+
[6.33096109e-02, 4.11683954e-02, 5.02962051e-02, 5.43097516e-06]])
610642
611643
References
612644
----------
@@ -701,8 +733,19 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
701733
>>> Y_target = rng.randn(n_target, 2)
702734
>>> M = ot.dist(X_source, Y_target)
703735
>>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
704-
>>> print(log['alpha'], log['beta'])
705-
>>> print(sgd_dual_pi)
736+
>>> log['alpha']
737+
array([0.64057733, 1.2683513 , 0.75610161, 0.16024284, 0.54926534,
738+
1.0514201 , 0.19958936])
739+
>>> log['beta']
740+
array([0.51372571, 0.58843489, 1.27993921, 2.24344807])
741+
>>> sgd_dual_pi
742+
array([[1.97377795e-02, 7.86706853e-02, 6.15682001e-03, 4.82586997e-09],
743+
[4.19566963e-03, 4.42016865e-04, 2.02777272e-05, 3.68823708e-07],
744+
[3.00379244e-03, 6.56562018e-02, 2.40462171e-02, 1.63579656e-10],
745+
[4.28626062e-02, 5.60031599e-02, 6.13193826e-02, 9.67977735e-07],
746+
[7.61972739e-02, 5.94609051e-03, 3.77886693e-02, 1.36046648e-03],
747+
[4.44810042e-02, 1.89476742e-03, 1.73285847e-03, 3.51826036e-04],
748+
[6.30118293e-02, 4.12398660e-02, 4.95148998e-02, 5.26247246e-06]])
706749
707750
References
708751
----------

ot/unbalanced.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
290290
>>> a=[.5, .15]
291291
>>> b=[.5, .5]
292292
>>> M=[[0., 1.],[1., 0.]]
293-
>>> ot.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
293+
>>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
294294
array([[0.52761554, 0.22392482],
295295
[0.10286295, 0.32257641]])
296296

0 commit comments

Comments
 (0)