
Commit 97feeb3

tgnassou, rflamary, and agramfort authored
[MRG] OT for Gaussian distributions (#428)
* add gaussian modules
* add gaussian modules
* add PR to release.md
* Apply suggestions from code review (Co-authored-by: Alexandre Gramfort <[email protected]>)
* Apply suggestions from code review (Co-authored-by: Alexandre Gramfort <[email protected]>)
* Update ot/gaussian.py
* Update ot/gaussian.py
* add empirical bures wasserstein distance, fix docstring and test
* update to fit with new networkx API
* add test for jax and tf
* fix test
* fix test?
* add empirical_bures_wasserstein_mapping
* fix docs
* fix doc
* fix docstring
* add tgnassou to contributors
* add more coverage for gaussian.py
* add deprecated function
* fix doc math
* fix doc math
* add remi flamary to authors of gaussian module
* fix equation

Co-authored-by: Rémi Flamary <[email protected]>
Co-authored-by: Alexandre Gramfort <[email protected]>
1 parent 058d275 commit 97feeb3

File tree

11 files changed: +448 −138 lines


CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ The contributors to this library are:
 * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
 * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
 * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
+* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
 
 ## Acknowledgments

RELEASES.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 
 #### New features
 
+- Added Bures Wasserstein distance in `ot.gaussian` (PR #428)
 - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
 - Added Free Support Sinkhorn Barycenter + example (PR #387)
 - New API for OT solver using function `ot.solve` (PR #388)
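
For reference, the Bures-Wasserstein distance mentioned in this release note has a closed form: the squared distance between N(m_s, C_s) and N(m_t, C_t) is ||m_s - m_t||^2 + Tr(C_s + C_t - 2 (C_s^{1/2} C_t C_s^{1/2})^{1/2}). Below is a minimal NumPy/SciPy sketch of that formula, illustrative only and not the `ot.gaussian` API added by this PR:

```python
# Minimal sketch of the closed-form squared Bures-Wasserstein distance between
# two Gaussians N(ms, Cs) and N(mt, Ct); illustrative only, not the ot.gaussian call.
import numpy as np
from scipy.linalg import sqrtm

def bures_wasserstein_sq(ms, mt, Cs, Ct):
    Cs12 = sqrtm(Cs)                              # Cs^{1/2}
    cross = sqrtm(Cs12 @ Ct @ Cs12)               # (Cs^{1/2} Ct Cs^{1/2})^{1/2}
    bures2 = np.trace(Cs + Ct - 2 * cross)        # squared Bures metric between covariances
    return np.sum((ms - mt) ** 2) + np.real(bures2)

ms, mt = np.zeros(2), np.array([1.0, 0.0])
Cs, Ct = np.eye(2), 2 * np.eye(2)
print(bures_wasserstein_sq(ms, mt, Cs, Ct))       # ~1.34 for this toy pair
```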

docs/source/all.rst

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ API and modules
    sliced
    weak
    factored
+   gaussian
 
 .. autosummary::
    :toctree: ../modules/generated/

docs/source/quickstart.rst

Lines changed: 3 additions & 3 deletions
@@ -279,7 +279,7 @@ distributions. In this case there exists a close form solution given in Remark
 2.29 in [15]_ and the Monge mapping is an affine function and can be
 also computed from the covariances and means of the source and target
 distributions. In the case when the finite sample dataset is supposed Gaussian,
-we provide :any:`ot.da.OT_mapping_linear` that returns the parameters for the
+we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters for the
 Monge mapping.
 
 
@@ -628,7 +628,7 @@ approximate a Monge mapping from finite distributions.
 First note that when the source and target distributions are supposed to be Gaussian
 distributions, there exists a close form solution for the mapping and its an
 affine function [14]_ of the form :math:`T(x)=Ax+b` . In this case we provide the function
-:any:`ot.da.OT_mapping_linear` that returns the operator :math:`A` and vector
+:any:`ot.gaussian.bures_wasserstein_mapping` that returns the operator :math:`A` and vector
 :math:`b`. Note that if the number of samples is too small there is a parameter
 :code:`reg` that provides a regularization for the covariance matrix estimation.
 
@@ -640,7 +640,7 @@ method proposed in [8]_ that estimates a continuous mapping approximating the
 barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for
 linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non-linear mapping.
 
-.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.da.OT_mapping_linear
+.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.gaussian.bures_wasserstein_mapping
    :add-heading: Examples of Monge mapping estimation
    :heading-level: "
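
As a quick illustration of the renamed entry point in the quickstart text above, here is a minimal usage sketch of the empirical estimator; the toy data, the `reg` value, and the variable names are illustrative:

```python
# Minimal sketch: estimate the affine Monge map T(x) = xA + b between two
# Gaussian-like samples with the new ot.gaussian module (toy data, illustrative reg).
import numpy as np
import ot

rng = np.random.RandomState(0)
xs = rng.randn(200, 2)                                               # source ~ N(0, I)
xt = rng.randn(200, 2) @ np.diag([2.0, 0.5]) + np.array([3.0, 1.0])  # shifted/scaled target

# reg regularizes the empirical covariances when few samples are available
A, b = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6)
xst = xs.dot(A) + b                                                  # transported source samples
```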

examples/domain-adaptation/plot_otda_linear_mapping.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@
 # Estimate linear mapping and transport
 # -------------------------------------
 
-Ae, be = ot.da.OT_mapping_linear(xs, xt)
+Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt)
 
 xst = xs.dot(Ae) + be

examples/gromov/plot_barycenter_fgw.py

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
 # -------------------------
 
 #%% Create the barycenter
-bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
+bary = nx.from_numpy_array(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
 for i, v in enumerate(A.ravel()):
     bary.add_node(i, attr_name=v)
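
The change above follows the networkx API update noted in the commit message: `from_numpy_matrix` was deprecated and removed in networkx 3.0 in favour of `from_numpy_array`. A small self-contained sketch of the replacement call, with a made-up adjacency matrix:

```python
# Sketch of the networkx API change: from_numpy_array builds a weighted graph
# from an array, replacing the removed from_numpy_matrix (toy adjacency matrix).
import numpy as np
import networkx as nx

adj = np.array([[0.0, 1.0, 0.0],
                [1.0, 0.0, 2.0],
                [0.0, 2.0, 0.0]])

G = nx.from_numpy_array(adj)        # was: nx.from_numpy_matrix(adj)
print(G.edges(data=True))           # edge weights come from the array entries
```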

ot/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -35,6 +35,7 @@
 from . import weak
 from . import factored
 from . import solvers
+from . import gaussian
 
 # OT functions
 from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -56,7 +57,7 @@
 
 __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
            'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
-           'emd2_1d', 'wasserstein_1d', 'backend',
+           'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
            'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
            'sinkhorn_unbalanced', 'barycenter_unbalanced',
           'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',

ot/da.py

Lines changed: 7 additions & 111 deletions
@@ -17,8 +17,9 @@
 from .bregman import sinkhorn, jcpot_barycenter
 from .lp import emd
 from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import list_to_array, check_params, BaseEstimator
+from .utils import list_to_array, check_params, BaseEstimator, deprecated
 from .unbalanced import sinkhorn_unbalanced
+from .gaussian import empirical_bures_wasserstein_mapping
 from .optim import cg
 from .optim import gcg
 
@@ -679,112 +680,7 @@ def df(G):
     return G, L
 
 
-def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
-                      wt=None, bias=True, log=False):
-    r"""Return OT linear operator between samples.
-
-    The function estimates the optimal linear operator that aligns the two
-    empirical distributions. This is equivalent to estimating the closed
-    form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
-    and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
-    :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in
-    :ref:`[15] <references-OT-mapping-linear>`.
-
-    The linear operator from source to target :math:`M`
-
-    .. math::
-        M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
-
-    where :
-
-    .. math::
-        \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
-        \Sigma_s^{-1/2}
-
-        \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
-
-    Parameters
-    ----------
-    xs : array-like (ns,d)
-        samples in the source domain
-    xt : array-like (nt,d)
-        samples in the target domain
-    reg : float,optional
-        regularization added to the diagonals of covariances (>0)
-    ws : array-like (ns,1), optional
-        weights for the source samples
-    wt : array-like (ns,1), optional
-        weights for the target samples
-    bias: boolean, optional
-        estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
-    log : bool, optional
-        record log if True
-
-
-    Returns
-    -------
-    A : (d, d) array-like
-        Linear operator
-    b : (1, d) array-like
-        bias
-    log : dict
-        log dictionary return only if log==True in parameters
-
-
-    .. _references-OT-mapping-linear:
-    References
-    ----------
-    .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
-        distributions", Journal of Optimization Theory and Applications
-        Vol 43, 1984
-
-    .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
-        Transport", 2018.
-
-    """
-    xs, xt = list_to_array(xs, xt)
-    nx = get_backend(xs, xt)
-
-    d = xs.shape[1]
-
-    if bias:
-        mxs = nx.mean(xs, axis=0)[None, :]
-        mxt = nx.mean(xt, axis=0)[None, :]
-
-        xs = xs - mxs
-        xt = xt - mxt
-    else:
-        mxs = nx.zeros((1, d), type_as=xs)
-        mxt = nx.zeros((1, d), type_as=xs)
-
-    if ws is None:
-        ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
-
-    if wt is None:
-        wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
-
-    Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
-    Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
-
-    Cs12 = nx.sqrtm(Cs)
-    Cs_12 = nx.inv(Cs12)
-
-    M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
-
-    A = dots(Cs_12, M0, Cs_12)
-
-    b = mxt - nx.dot(mxs, A)
-
-    if log:
-        log = {}
-        log['Cs'] = Cs
-        log['Ct'] = Ct
-        log['Cs12'] = Cs12
-        log['Cs_12'] = Cs_12
-        return A, b, log
-    else:
-        return A, b
+OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping)
 
 
 def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5,
@@ -1378,10 +1274,10 @@ class label
         self.mu_t = self.distribution_estimation(Xt)
 
         # coupling estimation
-        returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
-                                      ws=nx.reshape(self.mu_s, (-1, 1)),
-                                      wt=nx.reshape(self.mu_t, (-1, 1)),
-                                      bias=self.bias, log=self.log)
+        returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg,
+                                                        ws=nx.reshape(self.mu_s, (-1, 1)),
+                                                        wt=nx.reshape(self.mu_t, (-1, 1)),
+                                                        bias=self.bias, log=self.log)
 
         # deal with the value of log
         if self.log:
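
The docstring removed above gives the closed-form map A = Σ_s^{-1/2} (Σ_s^{1/2} Σ_t Σ_s^{1/2})^{1/2} Σ_s^{-1/2} and b = μ_t − A μ_s. Below is a plain NumPy/SciPy sketch of that computation from empirical moments, illustrative only; the maintained, backend-agnostic version is `ot.gaussian.empirical_bures_wasserstein_mapping`:

```python
# Hedged sketch: closed-form affine Monge map between two Gaussians estimated
# from samples, mirroring the formula in the removed docstring. Illustrative
# NumPy/SciPy only; the library version lives in ot.gaussian.
import numpy as np
from scipy.linalg import sqrtm, inv

def linear_monge_map(xs, xt, reg=1e-6):
    d = xs.shape[1]
    mxs, mxt = xs.mean(0, keepdims=True), xt.mean(0, keepdims=True)
    Cs = np.cov(xs.T) + reg * np.eye(d)            # regularized source covariance
    Ct = np.cov(xt.T) + reg * np.eye(d)            # regularized target covariance
    Cs12 = sqrtm(Cs)
    Cs_12 = inv(Cs12)
    A = Cs_12 @ sqrtm(Cs12 @ Ct @ Cs12) @ Cs_12    # Cs^{-1/2}(Cs^{1/2} Ct Cs^{1/2})^{1/2} Cs^{-1/2}
    b = mxt - mxs @ A                              # b = mu_t - A mu_s (row-vector convention, A symmetric)
    return np.real(A), np.real(b)

rng = np.random.RandomState(0)
xs = rng.randn(300, 2)
xt = rng.randn(300, 2) @ np.array([[2.0, 0.0], [0.0, 0.5]]) + np.array([4.0, -1.0])
A, b = linear_monge_map(xs, xt)
xst = xs @ A + b   # transported samples should roughly match xt's moments
```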

0 commit comments