Skip to content

Commit bbf2b79

Browse files
authored
gromov-wasserstein (#19)
* add gromov-wasserstein * add docstrings for gromov-wasserstein * update to LTS * bump CI Julia version
1 parent 9afa6af commit bbf2b79

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
fail-fast: false
2121
matrix:
2222
version:
23-
- '1.0'
23+
- '1.8'
2424
- '1'
2525
os:
2626
- ubuntu-latest

src/lib.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,3 +475,44 @@ See also: [`barycenter`](@ref)
475475
function barycenter_unbalanced(A, C, ε, λ; kwargs...)
476476
return pot.barycenter_unbalanced(A, C, ε, λ; kwargs...)
477477
end
478+
479+
"""
480+
gromov_wasserstein(μ, ν, Cμ, Cν, loss = "square_loss"; kwargs...)
481+
482+
Compute the exact Gromov-Wasserstein transport plan between `(μ, Cμ)` and `(ν, Cν)`.
483+
484+
The Gromov-Wasserstein transport problem seeks to find a minimizer of
485+
```math
486+
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, k, l} L((C_μ)_{ik}, (C_ν)_{jl}) \\gamma_{ij} \\gamma_{kl},
487+
```
488+
where ``L`` is quadratic (`loss = "square_loss"`) or the Kullback-Leibler divergence (`loss = "kl_loss"`).
489+
490+
This function is a wrapper of the function
491+
[`gromov_wasserstein`](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.gromov_wasserstein) in the
492+
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
493+
Python function.
494+
"""
495+
function gromov_wasserstein(μ, ν, Cμ, Cν, loss="square_loss"; kwargs...)
496+
return pot.gromov.gromov_wasserstein(Cμ, Cν, μ, ν, loss; kwargs...)
497+
end
498+
499+
"""
500+
entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss = "square_loss"; kwargs...)
501+
502+
Compute the entropy-regularized Gromov-Wasserstein transport plan between `(μ, Cμ)` and `(ν, Cν)` with parameter `ε`.
503+
504+
The entropy-regularized Gromov-Wasserstein transport problem seeks to find a minimizer of
505+
```math
506+
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, k, l} L((C_μ)_{ik}, (C_ν)_{jl}) \\gamma_{ij} \\gamma_{kl} + ε \\Omega(\\gamma),
507+
```
508+
where ``L`` is quadratic (`loss = "square_loss"`) or the Kullback-Leibler divergence (`loss = "kl_loss"`)
509+
and ``\\Omega(\\gamma) = \\sum_{ij} \\gamma_{ij} \\log(\\gamma_{ij})`` is the entropic regularization term.
510+
511+
This function is a wrapper of the function
512+
[`entropic_gromov_wasserstein`](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.entropic_gromov_wasserstein) in the
513+
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
514+
Python function.
515+
"""
516+
function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; kwargs...)
517+
return pot.gromov.entropic_gromov_wasserstein(Cμ, Cν, μ, ν, loss, ε; kwargs...)
518+
end

0 commit comments

Comments
 (0)