Skip to content

Commit 0d81de9

Browse files
committed
doc da.py
1 parent 4ced742 commit 0d81de9

File tree

5 files changed

+142
-57
lines changed

5 files changed

+142
-57
lines changed

docs/source/all.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ Python modules
66
ot
77
--
88

9-
This module provide easy access to solvers for the most common OT problems
10-
119
.. automodule:: ot
1210
:members:
1311

@@ -28,6 +26,12 @@ ot.optim
2826
.. automodule:: ot.optim
2927
:members:
3028

29+
ot.da
30+
--------
31+
32+
.. automodule:: ot.da
33+
:members:
34+
3135
ot.utils
3236
--------
3337

ot/__init__.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
1-
# Python Optimal Transport toolbox
1+
"""Python Optimal Transport toolbox"""
22

33
# All submodules and packages
4-
from . import lp
4+
from . import lp
55
from . import bregman
6-
from . import optim
6+
from . import optim
77
from . import utils
88
from . import datasets
99
from . import plot
1010
from . import da
1111

12-
13-
1412
# OT functions
1513
from .lp import emd
16-
from .bregman import sinkhorn,barycenter
14+
from .bregman import sinkhorn, barycenter
1715
from .da import sinkhorn_lpl1_mm
1816

1917
# utils functions
20-
from .utils import dist,unif
18+
from .utils import dist, unif
2119

22-
__all__ = ["emd","sinkhorn","utils",'datasets','bregman','lp','plot','dist','unif','barycenter','sinkhorn_lpl1_mm','da','optim']
20+
__all__ = ["emd", "sinkhorn", "utils", 'datasets', 'bregman', 'lp', 'plot',
21+
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

ot/bregman.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import numpy as np
77

88

9-
def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False):
9+
def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False):
1010
"""
11-
Solve the entropic regularization optimal transport problem and return the OT matrix
12-
11+
Solve the entropic regularization optimal transport problem
12+
1313
The function solves the following optimization problem:
14-
14+
1515
.. math::
1616
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
1717

ot/da.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
# -*- coding: utf-8 -*-
12
"""
2-
domain adaptation with optimal transport
3+
Domain adaptation with optimal transport
34
"""
5+
46
import numpy as np
57
from .bregman import sinkhorn
68

@@ -9,7 +11,88 @@
911
def indices(a, func):
1012
return [i for (i, val) in enumerate(a) if func(val)]
1113

12-
def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1):
14+
def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
15+
"""
16+
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
17+
18+
The function solves the following optimization problem:
19+
20+
.. math::
21+
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
22+
23+
s.t. \gamma 1 = a
24+
25+
\gamma^T 1= b
26+
27+
\gamma\geq 0
28+
where :
29+
30+
- M is the (ns,nt) metric cost matrix
31+
- :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
32+
- :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain.
33+
- a and b are source and target weights (sum to 1)
34+
35+
The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
36+
37+
38+
Parameters
39+
----------
40+
a : np.ndarray (ns,)
41+
samples weights in the source domain
42+
labels_a : np.ndarray (ns,)
43+
labels of samples in the source domain
44+
b : np.ndarray (nt,)
45+
samples in the target domain
46+
M : np.ndarray (ns,nt)
47+
loss matrix
48+
reg: float
49+
Regularization term for entropic regularization >0
50+
eta: float, optional
51+
Regularization term for group lasso regularization >0
52+
numItermax: int, optional
53+
Max number of iterations
54+
numInnerItermax: int, optional
55+
Max number of iterations (inner sinkhorn solver)
56+
stopInnerThr: float, optional
57+
Stop threshold on error (inner sinkhorn solver) (>0)
58+
verbose : bool, optional
59+
Print information along iterations
60+
log : bool, optional
61+
record log if True
62+
63+
64+
Returns
65+
-------
66+
gamma: (ns x nt) ndarray
67+
Optimal transportation matrix for the given parameters
68+
log: dict
69+
log dictionary return only if log==True in parameters
70+
71+
Examples
72+
--------
73+
74+
>>> a=[.5,.5]
75+
>>> b=[.5,.5]
76+
>>> M=[[0.,1.],[1.,0.]]
77+
>>> ot.sinkhorn(a,b,M,1)
78+
array([[ 0.36552929, 0.13447071],
79+
[ 0.13447071, 0.36552929]])
80+
81+
82+
References
83+
----------
84+
85+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
86+
87+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
88+
89+
See Also
90+
--------
91+
ot.lp.emd : Unregularized OT
92+
ot.bregman.sinkhorn : Entropic regularized OT
93+
ot.optim.cg : General regularized OT
94+
95+
"""
1396
p=0.5
1497
epsilon = 1e-3
1598

@@ -25,9 +108,9 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1):
25108

26109
W=np.zeros(M.shape)
27110

28-
for cpt in range(10):
111+
for cpt in range(numItermax):
29112
Mreg = M + eta*W
30-
transp=sinkhorn(a,b,Mreg,reg,numItermax = 200)
113+
transp=sinkhorn(a,b,Mreg,reg,numItermax=numInnerItermax, stopThr=stopInnerThr)
31114
# the transport has been computed. Check if classes are really separated
32115
W = np.ones((Nini,Nfin))
33116
for t in range(Nfin):

ot/lp/__init__.py

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,78 @@
1+
# -*- coding: utf-8 -*-
12
"""
23
Solvers for the original linear program OT problem
34
"""
45

6+
import numpy as np
57
# import compiled emd
68
from .emd import emd_c
7-
import numpy as np
89

9-
def emd(a,b,M):
10-
"""
11-
Solves the Earth Movers distance problem and returns the optimal transport matrix
12-
13-
10+
11+
def emd(a, b, M):
12+
"""Solves the Earth Movers distance problem and returns the OT matrix
13+
14+
1415
.. math::
15-
\gamma = arg\min_\gamma <\gamma,M>_F
16-
16+
\gamma = arg\min_\gamma <\gamma,M>_F
17+
1718
s.t. \gamma 1 = a
18-
19-
\gamma^T 1= b
20-
19+
\gamma^T 1= b
2120
\gamma\geq 0
2221
where :
23-
22+
2423
- M is the metric cost matrix
2524
- a and b are the sample weights
26-
25+
2726
Uses the algorithm proposed in [1]_
28-
27+
2928
Parameters
3029
----------
3130
a : (ns,) ndarray, float64
3231
Source histogram (uniform weigth if empty list)
3332
b : (nt,) ndarray, float64
3433
Target histogram (uniform weigth if empty list)
3534
M : (ns,nt) ndarray, float64
36-
loss matrix
37-
35+
loss matrix
36+
3837
Returns
3938
-------
4039
gamma: (ns x nt) ndarray
4140
Optimal transportation matrix for the given parameters
42-
43-
41+
42+
4443
Examples
4544
--------
46-
45+
4746
Simple example with obvious solution. The function emd accepts lists and
48-
perform automatic conversion to numpy arrays
49-
47+
perform automatic conversion to numpy arrays
48+
5049
>>> a=[.5,.5]
5150
>>> b=[.5,.5]
5251
>>> M=[[0.,1.],[1.,0.]]
5352
>>> ot.emd(a,b,M)
5453
array([[ 0.5, 0. ],
5554
[ 0. , 0.5]])
56-
55+
5756
References
5857
----------
59-
60-
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
61-
58+
59+
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
60+
(2011, December). Displacement interpolation using Lagrangian mass
61+
transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
62+
158). ACM.
63+
6264
See Also
6365
--------
6466
ot.bregman.sinkhorn : Entropic regularized OT
65-
ot.optim.cg : General regularized OT
66-
67-
68-
"""
69-
a=np.asarray(a,dtype=np.float64)
70-
b=np.asarray(b,dtype=np.float64)
71-
M=np.asarray(M,dtype=np.float64)
72-
73-
if len(a)==0:
74-
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
75-
if len(b)==0:
76-
b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
77-
78-
return emd_c(a,b,M)
67+
ot.optim.cg : General regularized OT"""
68+
69+
a = np.asarray(a, dtype=np.float64)
70+
b = np.asarray(b, dtype=np.float64)
71+
M = np.asarray(M, dtype=np.float64)
72+
73+
if len(a) == 0:
74+
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
75+
if len(b) == 0:
76+
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
7977

78+
return emd_c(a, b, M)

0 commit comments

Comments
 (0)