Skip to content

Commit e0c935a

Browse files
author
Hicham Janati
committed
improve doc
1 parent b639e3e commit e0c935a

File tree

3 files changed

+28
-25
lines changed

3 files changed

+28
-25
lines changed

docs/source/quickstart.rst

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -577,10 +577,10 @@ Unbalanced optimal transport
577577

578578
Unbalanced OT is a relaxation of the entropy regularized OT problem where the violation of
579579
the constraint on the marginals is added to the objective of the optimization
580-
problem. The unbalanced OT metric between two histograms a and b is defined as [25]_ [10]_:
580+
problem. The unbalanced OT metric between two unbalanced histograms a and b is defined as [25]_ [10]_:
581581

582582
.. math::
583-
W_u(a, b) = \min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + reg\cdot\Omega(\gamma) + \alpha KL(\gamma 1, a) + \alpha KL(\gamma^T 1, b)
583+
W_u(a, b) = \min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
584584
585585
s.t. \quad \gamma\geq 0
586586
@@ -593,13 +593,11 @@ in :any:`ot.unbalanced`. Computing the optimal transport
593593
plan or the transport cost is similar to the balanced case. The Sinkhorn-Knopp
594594
algorithm is implemented in :any:`ot.sinkhorn_unbalanced` and :any:`ot.sinkhorn_unbalanced2`
595595
that return respectively the OT matrix and the value of the
596-
linear term. Note that the regularization parameter :math:`\alpha` in the
597-
equation above is given to those functions with the parameter :code:`reg_m`.
598-
596+
linear term.
599597

600598
.. note::
601599
The main function to solve entropic regularized UOT is :any:`ot.sinkhorn_unbalanced`.
602-
This function is a wrapper and the parameter :code:`method` help you select
600+
This function is a wrapper and the parameter :code:`method` helps you select
603601
the actual algorithm used to solve the problem:
604602

605603
+ :code:`method='sinkhorn'` calls :any:`ot.unbalanced.sinkhorn_knopp_unbalanced`

ot/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,5 @@
7777
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
7878
'emd_1d', 'emd2_1d', 'wasserstein_1d',
7979
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
80-
'sinkhorn_unbalanced', "barycenter_unbalanced"]
80+
'sinkhorn_unbalanced', 'barycenter_unbalanced',
81+
'sinkhorn_unbalanced2']

ot/bregman.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
3535
3636
- M is the (dim_a, dim_b) metric cost matrix
3737
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
38-
- a and b are source and target weights (sum to 1)
38+
- a and b are source and target weights (histograms, both sum to 1)
3939
4040
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
4141
@@ -143,7 +143,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
143143
144144
- M is the (dim_a, dim_b) metric cost matrix
145145
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
146-
- a and b are source and target weights (sum to 1)
146+
- a and b are source and target weights (histograms, both sum to 1)
147147
148148
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
149149
@@ -251,7 +251,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
251251
252252
- M is the (dim_a, dim_b) metric cost matrix
253253
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
254-
- a and b are source and target weights (sum to 1)
254+
- a and b are source and target weights (histograms, both sum to 1)
255255
256256
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
257257
@@ -432,7 +432,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
432432
433433
- M is the (dim_a, dim_b) metric cost matrix
434434
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
435-
- a and b are source and target weights (sum to 1)
435+
- a and b are source and target weights (histograms, both sum to 1)
436436
437437
438438
@@ -578,7 +578,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
578578
579579
- M is the (dim_a, dim_b) metric cost matrix
580580
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
581-
- a and b are source and target weights (sum to 1)
581+
- a and b are source and target weights (histograms, both sum to 1)
582+
582583
583584
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
584585
scaling algorithm as proposed in [2]_ but with the log stabilization
@@ -808,7 +809,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
808809
809810
- M is the (dim_a, dim_b) metric cost matrix
810811
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
811-
- a and b are source and target weights (sum to 1)
812+
- a and b are source and target weights (histograms, both sum to 1)
813+
812814
813815
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
814816
scaling algorithm as proposed in [2]_ but with the log stabilization
@@ -1229,7 +1231,6 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
12291231
absorbing = False
12301232
if (u > tau).any() or (v > tau).any():
12311233
absorbing = True
1232-
print("YEAH absorbing")
12331234
alpha = alpha + reg * np.log(np.max(u, 1))
12341235
beta = beta + reg * np.log(np.max(v, 1))
12351236
K = np.exp((alpha[:, None] + beta[None, :] -
@@ -1394,26 +1395,29 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
13941395
where :
13951396
13961397
- :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
1397-
- :math:`\mathbf{a}` is an observed distribution, :math:`\mathbf{h}_0` is aprior on unmixing
1398-
- reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT data fitting
1399-
- reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix for regularization
1398+
- :math: `\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
1399+
- :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
1400+
- :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
1401+
- :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior`
1402+
- reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting
1403+
- reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) regularization
14001404
- :math:`\\alpha`weight data fitting and regularization
14011405
14021406
The optimization problem is solved suing the algorithm described in [4]
14031407
14041408
14051409
Parameters
14061410
----------
1407-
a : ndarray, shape (n_observed)
1408-
observed distribution
1409-
D : ndarray, shape (dim, dim)
1411+
a : ndarray, shape (dim_a)
1412+
observed distribution (histogram, sums to 1)
1413+
D : ndarray, shape (dim_a, n_atoms)
14101414
dictionary matrix
1411-
M : ndarray, shape (dim, dim)
1415+
M : ndarray, shape (dim_a, dim_a)
14121416
loss matrix
1413-
M0 : ndarray, shape (n_observed, n_observed)
1417+
M0 : ndarray, shape (n_atoms, dim_prior)
14141418
loss matrix
1415-
h0 : ndarray, shape (dim,)
1416-
prior on h
1419+
h0 : ndarray, shape (n_atoms,)
1420+
prior on the estimated unmixing h
14171421
reg : float
14181422
Regularization term >0 (Wasserstein data fitting)
14191423
reg0 : float
@@ -1432,7 +1436,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
14321436
14331437
Returns
14341438
-------
1435-
a : ndarray, shape (dim,)
1439+
h : ndarray, shape (n_atoms,)
14361440
Wasserstein barycenter
14371441
log : dict
14381442
log dictionary return only if log==True in parameters

0 commit comments

Comments
 (0)