Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] implementation of low rank ot via factor relaxation paper #719

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Automatic PR labeling and release file update check (PR #704)
- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
- Fix documentation in the module `ot.gaussian` (PR #718)
- Implement low-rank optimal transport via Factor Relaxation with Latent Coupling (PR #719)

#### Closed issues
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
Expand Down
12 changes: 12 additions & 0 deletions ot/low_rank/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
"""
Low rank Solvers
"""

# Author: Yessin Moakher <[email protected]>
#
# License: MIT License

from ._factor_relaxation import solve_balanced_FRLC

__all__ = ["solve_balanced_FRLC"]
225 changes: 225 additions & 0 deletions ot/low_rank/_factor_relaxation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# -*- coding: utf-8 -*-
"""
Low rank Solvers
"""

# Author: Yessin Moakher <[email protected]>
#
# License: MIT License

from ..utils import list_to_array
from ..backend import get_backend
from ..bregman import sinkhorn
from ..unbalanced import sinkhorn_unbalanced


def _initialize_couplings(a, b, r, nx, reg_init=1, random_state=42):
    """Build initial feasible couplings Q, R, T for the Factor Relaxation algorithm.

    Random cost matrices are drawn (reproducibly, via ``random_state``) and
    entropically regularized OT problems are solved so that each coupling
    starts on its transport polytope.
    """
    n, m = a.shape[0], b.shape[0]

    nx.seed(seed=random_state)
    # Random costs; the draw order (n x r, then m x r, then r x r) fixes
    # reproducibility for a given seed.
    cost_Q = nx.rand(n, r, type_as=a)
    cost_R = nx.rand(m, r, type_as=a)
    cost_T = nx.rand(r, r, type_as=a)

    # Uniform inner marginal of size r, shared by the Q and R problems.
    uniform_r = nx.full(r, 1 / r, type_as=a)  # Shape (r,)

    Q = sinkhorn(a, uniform_r, cost_Q, reg_init, method="sinkhorn_log")
    R = sinkhorn(b, uniform_r, cost_R, reg_init, method="sinkhorn_log")

    # T couples the column marginals induced by Q and R.
    T = sinkhorn(
        nx.dot(Q.T, nx.ones(n, type_as=a)),
        nx.dot(R.T, nx.ones(m, type_as=a)),
        cost_T,
        reg_init,
        method="sinkhorn_log",
    )

    return Q, R, T

Check warning on line 42 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L42

Added line #L42 was not covered by tests


def _compute_gradient_Q(M, Q, R, X, g_Q, nx):
"""Compute the gradient of the loss with respect to Q."""

n = Q.shape[0]

Check warning on line 48 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L48

Added line #L48 was not covered by tests

term1 = nx.dot(

Check warning on line 50 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L50

Added line #L50 was not covered by tests
nx.dot(M, R), X.T
) # The order of multiplications is important because r<<min{n,m}
term2 = nx.diag(nx.dot(nx.dot(term1.T, Q), nx.diag(1 / g_Q))).reshape(1, -1)
term3 = nx.dot(nx.ones((n, 1), type_as=M), term2)
grad_Q = term1 - term3

Check warning on line 55 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L53-L55

Added lines #L53 - L55 were not covered by tests

return grad_Q

Check warning on line 57 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L57

Added line #L57 was not covered by tests


def _compute_gradient_R(M, Q, R, X, g_R, nx):
"""Compute the gradient of the loss with respect to R."""

m = R.shape[0]

Check warning on line 63 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L63

Added line #L63 was not covered by tests

term1 = nx.dot(nx.dot(M.T, Q), X)
term2 = nx.diag(nx.dot(nx.diag(1 / g_R), nx.dot(R.T, term1))).reshape(1, -1)
term3 = nx.dot(nx.ones((m, 1), type_as=M), term2)
grad_R = term1 - term3

Check warning on line 68 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L65-L68

Added lines #L65 - L68 were not covered by tests

return grad_R

Check warning on line 70 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L70

Added line #L70 was not covered by tests


def _compute_gradient_T(Q, R, M, g_Q, g_R, nx):
"""Compute the gradient of the loss with respect to T."""

term_1 = nx.dot(nx.dot(Q.T, M), R)
return nx.dot(nx.dot(nx.diag(1 / g_Q), term_1), nx.diag(1 / g_R))

Check warning on line 77 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L76-L77

Added lines #L76 - L77 were not covered by tests


def _compute_distance(Q_new, R_new, T_new, Q, R, T, nx):
"""Compute the distance between the new and the old couplings."""

return (

Check warning on line 83 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L83

Added line #L83 was not covered by tests
nx.sum((Q_new - Q) ** 2) + nx.sum((R_new - R) ** 2) + nx.sum((T_new - T) ** 2)
)


def solve_balanced_FRLC(
    a,
    b,
    M,
    r,
    tau,
    gamma,
    stopThr=1e-7,
    numItermax=1000,
    log=False,
):
    r"""
    Solve the low-rank balanced optimal transport problem using Factor Relaxation
    with Latent Coupling and return the OT matrix.

    The function solves the following optimization problem:

    .. math::
        \textbf{P} = \mathop{\arg \min}_P \quad \langle \textbf{P}, \mathbf{M} \rangle_F

        \text{s.t.} \ \textbf{P} = \textbf{Q} \operatorname{diag}(1/g_Q)\textbf{T}\operatorname{diag}(1/g_R)\textbf{R}^T

        \textbf{Q} \in \Pi_{a,\cdot}, \quad \textbf{R} \in \Pi_{b,\cdot}, \quad \textbf{T} \in \Pi_{g_Q,g_R}

        \textbf{Q} \in \mathbb{R}^+_{n,r},\textbf{R} \in \mathbb{R}^+_{m,r},\textbf{T} \in \mathbb{R}^+_{r,r}

    where:

    - :math:`\mathbf{M}` is the given cost matrix.
    - :math:`g_Q := \mathbf{Q}^T 1_n, \quad g_R := \mathbf{R}^T 1_m`.
    - :math:`\Pi_{a, \cdot} := \left\{ \mathbf{P} \mid \mathbf{P} 1_m = a \right\}, \quad \Pi_{\cdot, b} := \left\{ \mathbf{P} \mid \mathbf{P}^T 1_n = b \right\}, \quad \Pi_{a,b} := \Pi_{a, \cdot} \cap \Pi_{\cdot, b}`.


    Parameters
    ----------
    a : array-like, shape (n,)
        Sample weights in the source domain.
    b : array-like, shape (m,)
        Sample weights in the target domain.
    M : array-like, shape (n, m)
        Loss (cost) matrix.
    r : int
        Rank constraint for the transport plan P.
    tau : float
        Regularization parameter controlling the relaxation of the inner marginals.
    gamma : float
        Step size (learning rate) for the coordinate mirror-descent algorithm.
    stopThr : float, optional
        Stop threshold on error (>0).
    numItermax : int, optional
        Max number of iterations for the mirror-descent optimization.
    log : bool, optional
        Print the transport cost at each iteration.

    Returns
    -------
    P : array-like, shape (n, m)
        The computed low-rank optimal transportation matrix. If the convergence
        criterion is not met within ``numItermax`` iterations, the plan built
        from the last iterate is returned.

    References
    ----------
    [1] Halmos, P., Liu, X., Gold, J., & Raphael, B. (2024). Low-Rank Optimal Transport through Factor Relaxation with Latent Coupling.
    In Proceedings of the Thirty-eighth Annual Conference on Neural Information Processing Systems (NeurIPS 2024).
    """

    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)

    n, m = M.shape

    ones_n, ones_m = (
        nx.ones(n, type_as=M),
        nx.ones(m, type_as=M),
    )  # Shape (n,) and (m,)

    Q, R, T = _initialize_couplings(a, b, r, nx)  # Shape (n,r), (m,r), (r,r)
    g_Q, g_R = nx.dot(Q.T, ones_n), nx.dot(R.T, ones_m)  # Shape (r,) and (r,)
    X = nx.dot(nx.dot(nx.diag(1 / g_Q), T), nx.diag(1 / g_R))  # Shape (r,r)

    for i in range(numItermax):
        grad_Q = _compute_gradient_Q(M, Q, R, X, g_Q, nx)  # Shape (n,r)
        grad_R = _compute_gradient_R(M, Q, R, X, g_R, nx)  # Shape (m,r)

        gamma_k = gamma / max(
            nx.max(nx.abs(grad_Q)), nx.max(nx.abs(grad_R))
        )  # l-inf normalization

        # The Q and R updates are independent of each other and could be
        # computed in parallel.
        Q_new = sinkhorn_unbalanced(
            a=a,
            b=g_Q,
            M=grad_Q,
            c=Q,
            reg=1 / gamma_k,
            reg_m=[float("inf"), tau],
            method="sinkhorn_stabilized",
        )

        R_new = sinkhorn_unbalanced(
            a=b,
            b=g_R,
            M=grad_R,
            c=R,
            reg=1 / gamma_k,
            reg_m=[float("inf"), tau],
            method="sinkhorn_stabilized",
        )

        # Refresh the inner marginals from the updated factors.
        g_Q = nx.dot(Q_new.T, ones_n)
        g_R = nx.dot(R_new.T, ones_m)

        grad_T = _compute_gradient_T(Q_new, R_new, M, g_Q, g_R, nx)  # Shape (r, r)

        gamma_T = gamma / nx.max(nx.abs(grad_T))

        # The latent coupling T is updated with both marginals kept balanced.
        T_new = sinkhorn_unbalanced(
            M=grad_T,
            a=g_Q,
            b=g_R,
            reg=1 / gamma_T,
            c=T,
            reg_m=[float("inf"), float("inf")],
            method="sinkhorn_stabilized",
        )  # Shape (r, r)

        X_new = nx.dot(nx.dot(nx.diag(1 / g_Q), T_new), nx.diag(1 / g_R))  # Shape (r,r)

        if log:
            print(f"iteration {i} ", nx.sum(M * nx.dot(nx.dot(Q_new, X_new), R_new.T)))

        # Converged: the iterates moved less than the (step-scaled) threshold.
        if (
            _compute_distance(Q_new, R_new, T_new, Q, R, T, nx)
            < gamma_k * gamma_k * stopThr
        ):
            return nx.dot(nx.dot(Q_new, X_new), R_new.T)  # Shape (n, m)

        Q, R, T, X = Q_new, R_new, T_new, X_new

    # numItermax reached without meeting the convergence criterion: previously
    # the function fell through the loop and implicitly returned None; return
    # the plan built from the last iterate instead.
    return nx.dot(nx.dot(Q, X), R.T)  # Shape (n, m)

Check warning on line 225 in ot/low_rank/_factor_relaxation.py

View check run for this annotation

Codecov / codecov/patch

ot/low_rank/_factor_relaxation.py#L225

Added line #L225 was not covered by tests
Loading