Skip to content

Commit 3cc99e6

Browse files
committed
better documentation
1 parent 140baad commit 3cc99e6

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,5 @@ This toolbox benefit a lot from open source research and we would like to thank
105105
[9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
106106

107107
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
108+
109+
[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class Mock(MagicMock):
2424
@classmethod
2525
def __getattr__(cls, name):
2626
return Mock()
27-
MOCK_MODULES = [ 'emd','ot.lp.emd']
27+
MOCK_MODULES = [ 'emd','ot.lp.emd_wrap']
2828
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
2929
# !!!!
3030

ot/dr.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
Domain adaptation with optimal transport
3+
Dimension reduction with optimal transport
44
"""
55

6-
76
import autograd.numpy as np
87
from pymanopt.manifolds import Stiefel
98
from pymanopt import Problem
109
from pymanopt.solvers import SteepestDescent, TrustRegions
1110

1211
def dist(x1,x2):
13-
""" Compute squared euclidena distance between samples
12+
""" Compute squared euclidean distance between samples
1413
"""
1514
x1p2=np.sum(np.square(x1),1)
1615
x2p2=np.sum(np.square(x2),1)
@@ -40,18 +39,66 @@ def split_classes(X,y):
4039

4140
def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
4241
"""
43-
Wasserstein Discriminant Analysis
42+
Wasserstein Discriminant Analysis [11]_
4443
4544
The function solves the following optimization problem:
4645
4746
.. math::
48-
P = arg\min_P \frac{\sum_i W(PX^i,PX^i)}{\sum_{i,j\neq i} W(PX^i,PX^j)}
47+
P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
4948
5049
where :
51-
50+
51+
- :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
5252
- :math:`W` is entropic regularized Wasserstein distances
5353
- :math:`X^i` are samples in the dataset corresponding to class i
5454
55+
Parameters
56+
----------
57+
a : np.ndarray (ns,)
58+
samples weights in the source domain
59+
b : np.ndarray (nt,)
60+
samples in the target domain
61+
M : np.ndarray (ns,nt)
62+
loss matrix
63+
reg : float
64+
Regularization term >0
65+
numItermax : int, optional
66+
Max number of iterations
67+
stopThr : float, optional
68+
Stop threshol on error (>0)
69+
verbose : bool, optional
70+
Print information along iterations
71+
log : bool, optional
72+
record log if True
73+
74+
75+
Returns
76+
-------
77+
gamma : (ns x nt) ndarray
78+
Optimal transportation matrix for the given parameters
79+
log : dict
80+
log dictionary return only if log==True in parameters
81+
82+
Examples
83+
--------
84+
85+
>>> import ot
86+
>>> a=[.5,.5]
87+
>>> b=[.5,.5]
88+
>>> M=[[0.,1.],[1.,0.]]
89+
>>> ot.sinkhorn(a,b,M,1)
90+
array([[ 0.36552929, 0.13447071],
91+
[ 0.13447071, 0.36552929]])
92+
93+
94+
References
95+
----------
96+
97+
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
98+
99+
100+
101+
55102
"""
56103

57104
mx=np.mean(X)

0 commit comments

Comments
 (0)