Skip to content

Commit 8b41e14

Browse files
committed
add log and epsilon scaling stabilizations
1 parent e485078 commit 8b41e14

File tree

5 files changed

+362
-9
lines changed

5 files changed

+362
-9
lines changed

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ This open source Python library provide several solvers for optimization problem
88
It provides the following solvers:
99

1010
* OT solver for the linear program/ Earth Movers Distance [1].
11-
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2].
11+
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10].
1212
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
1313
* Optimal transport for domain adaptation with group lasso regularization [5]
1414
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
@@ -98,3 +98,7 @@ This toolbox benefit a lot from open source research and we would like to thank
9898
[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
9999

100100
[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
101+
102+
[9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
103+
104+
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.

examples/demo_OT_1D.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
x=np.arange(n,dtype=np.float64)
2020

2121
# Gaussian distributions
22-
a=gauss(n,m=20,s=20) # m= mean, s= std
23-
b=gauss(n,m=60,s=60)
22+
a=gauss(n,m=20,s=5) # m= mean, s= std
23+
b=gauss(n,m=60,s=10)
2424

2525
# loss matrix
2626
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))

examples/demo_optim_OTreg.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
x=np.arange(n,dtype=np.float64)
1818

1919
# Gaussian distributions
20-
a=ot.datasets.get_1D_gauss(n,m=20,s=20) # m= mean, s= std
21-
b=ot.datasets.get_1D_gauss(n,m=60,s=60)
20+
a=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
21+
b=ot.datasets.get_1D_gauss(n,m=60,s=10)
2222

2323
# loss matrix
2424
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
@@ -37,7 +37,7 @@ def f(G): return 0.5*np.sum(G**2)
3737
def df(G): return G
3838

3939
reg=1e-1
40-
40+
4141
Gl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True)
4242

4343
pl.figure(3)
@@ -47,9 +47,9 @@ def df(G): return G
4747

4848
def f(G): return np.sum(G*np.log(G))
4949
def df(G): return np.log(G)+1
50-
50+
5151
reg=1e-3
52-
52+
5353
Ge=ot.optim.cg(a,b,M,reg,f,df,verbose=True)
5454

5555
pl.figure(4)

0 commit comments

Comments
 (0)