Skip to content

Commit b06cc02

Browse files
authored
Add files from OptimalTransport.jl (#1)
1 parent de29be0 commit b06cc02

File tree

9 files changed

+210
-17
lines changed

9 files changed

+210
-17
lines changed

.github/workflows/CI.yml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
jobs:
1010
test:
11-
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
11+
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ (matrix.python && 'system Python') || 'conda' }}
1212
runs-on: ${{ matrix.os }}
1313
strategy:
1414
fail-fast: false
@@ -23,13 +23,27 @@ jobs:
2323
- windows-latest
2424
arch:
2525
- x64
26+
python:
27+
- ''
28+
- python
2629
include:
2730
- version: '1'
2831
os: ubuntu-latest
2932
arch: x64
3033
coverage: true
3134
steps:
3235
- uses: actions/checkout@v2
36+
- name: Install python
37+
uses: actions/setup-python@v2
38+
with:
39+
python-version: '3.x'
40+
architecture: ${{ matrix.arch }}
41+
if: matrix.python
42+
# Limitation of pip: https://pythonot.github.io/index.html#pip-installation
43+
- run: python -m pip install cython numpy
44+
if: matrix.python
45+
- run: python -m pip install pot
46+
if: matrix.python
3347
- uses: julia-actions/setup-julia@v1
3448
with:
3549
version: ${{ matrix.version }}
@@ -45,7 +59,11 @@ jobs:
4559
${{ runner.os }}-test-
4660
${{ runner.os }}-
4761
- uses: julia-actions/julia-buildpkg@v1
62+
env:
63+
PYTHON: ${{ matrix.python }}
4864
- uses: julia-actions/julia-runtest@v1
65+
env:
66+
PYTHON: ${{ matrix.python }}
4967
- uses: julia-actions/julia-processcoverage@v1
5068
if: matrix.coverage
5169
- uses: codecov/codecov-action@v1

.github/workflows/Docs.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ jobs:
2121
using Pkg
2222
Pkg.develop(PackageSpec(path=pwd()))
2323
Pkg.instantiate()'
24+
env:
25+
PYTHON: ''
2426
- run: julia --project=docs docs/make.jl
2527
env:
2628
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
2729
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
30+
PYTHON: ''

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
33
authors = ["David Widmann"]
44
version = "0.1.0"
55

6+
[deps]
7+
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
8+
69
[compat]
10+
PyCall = "1"
711
julia = "1"
812

913
[extras]

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# POT.jl
22

3-
Julia interface for the Python Optimal Transport (POT) library
3+
*Julia interface for the [Python Optimal Transport (POT) package](https://pythonot.github.io/)*
44

55
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://devmotion.github.io/POT.jl/stable)
66
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://devmotion.github.io/POT.jl/dev)
77
[![Build Status](https://github.com/devmotion/POT.jl/workflows/CI/badge.svg)](https://github.com/devmotion/POT.jl/actions)
88
[![Coverage](https://codecov.io/gh/devmotion/POT.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/devmotion/POT.jl)
99
[![Coverage](https://coveralls.io/repos/github/devmotion/POT.jl/badge.svg?branch=master)](https://coveralls.io/github/devmotion/POT.jl?branch=master)
1010
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
11+
12+
This package was originally part of [OptimalTransport.jl](https://github.com/zsteve/OptimalTransport.jl).

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ makedocs(;
1919
canonical="https://devmotion.github.io/POT.jl",
2020
assets=String[],
2121
),
22-
pages=["Home" => "index.md"],
22+
pages=["Home" => "index.md", "api.md"],
2323
strict=true,
2424
checkdocs=:exports,
2525
)

docs/src/api.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# API
2+
3+
## Exact optimal transport (Kantorovich) problem
4+
5+
```@docs
6+
emd
7+
emd2
8+
```
9+
10+
## Entropically regularised optimal transport
11+
12+
```@docs
13+
sinkhorn
14+
sinkhorn2
15+
```
16+
17+
## Unbalanced optimal transport
18+
19+
```@docs
20+
sinkhorn_unbalanced
21+
sinkhorn_unbalanced2
22+
```

docs/src/index.md

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,3 @@
1-
```@meta
2-
CurrentModule = POT
3-
```
1+
# POT.jl
42

5-
# POT
6-
7-
Documentation for [POT](https://github.com/devmotion/POT.jl).
8-
9-
```@index
10-
```
11-
12-
```@autodocs
13-
Modules = [POT]
14-
```
3+
*Julia interface for the [Python Optimal Transport (POT) package](https://pythonot.github.io/)*

src/POT.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
module POT
22

3-
# Write your package code here.
3+
using PyCall: PyCall
4+
5+
export emd, emd2, sinkhorn, sinkhorn2, sinkhorn_unbalanced, sinkhorn_unbalanced2
6+
7+
const pot = PyCall.PyNULL()
8+
9+
include("lib.jl")
10+
11+
function __init__()
12+
return copy!(pot, PyCall.pyimport_conda("ot", "pot", "conda-forge"))
13+
end
414

515
end

src/lib.jl

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""
2+
emd(mu, nu, C)
3+
4+
Compute transport map for Monge-Kantorovich problem with source and target marginals `mu`
5+
and `nu` and a cost matrix `C` of dimensions `(length(mu), length(nu))`.
6+
7+
Return optimal transport coupling `γ` of the same dimensions as `C` which solves
8+
9+
```math
10+
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle
11+
```
12+
13+
This function is a wrapper of the function
14+
[`emd`](https://pythonot.github.io/all.html#ot.emd) in the Python Optimal Transport
15+
package.
16+
"""
17+
function emd(mu, nu, C)
18+
return pot.lp.emd(nu, mu, PyCall.PyReverseDims(C))'
19+
end
20+
21+
"""
22+
emd2(mu, nu, C)
23+
24+
Compute exact transport cost for Monge-Kantorovich problem with source and target marginals
25+
`mu` and `nu` and a cost matrix `C` of dimensions `(length(mu), length(nu))`.
26+
27+
Returns optimal transport cost (a scalar), i.e. the optimal value
28+
29+
```math
30+
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle
31+
```
32+
33+
This function is a wrapper of the function
34+
[`emd2`](https://pythonot.github.io/all.html#ot.emd2) in the Python Optimal Transport
35+
package.
36+
"""
37+
function emd2(mu, nu, C)
38+
return pot.lp.emd2(nu, mu, PyCall.PyReverseDims(C))[1]
39+
end
40+
41+
"""
42+
sinkhorn(mu, nu, C, eps; tol=1e-9, max_iter = 1000, method = "sinkhorn", verbose = false)
43+
44+
Compute optimal transport map of histograms `mu` and `nu` with cost matrix `C` and entropic
45+
regularization parameter `eps`.
46+
47+
Method can be a choice of `"sinkhorn"`, `"greenkhorn"`, `"sinkhorn_stabilized"`, or
48+
`"sinkhorn_epsilon_scaling"` (Flamary et al., 2017).
49+
50+
This function is a wrapper of the function
51+
[`sinkhorn`](https://pythonot.github.io/all.html?highlight=sinkhorn#ot.sinkhorn) in the
52+
Python Optimal Transport package.
53+
"""
54+
function sinkhorn(mu, nu, C, eps; tol=1e-9, max_iter=1000, method="sinkhorn", verbose=false)
55+
return pot.sinkhorn(
56+
nu,
57+
mu,
58+
PyCall.PyReverseDims(C),
59+
eps;
60+
stopThr=tol,
61+
numItermax=max_iter,
62+
method=method,
63+
verbose=verbose,
64+
)'
65+
end
66+
67+
"""
68+
sinkhorn2(mu, nu, C, eps; tol=1e-9, max_iter = 1000, method = "sinkhorn", verbose = false)
69+
70+
Compute optimal transport cost of histograms `mu` and `nu` with cost matrix `C` and
71+
entropic regularization parameter `eps`.
72+
73+
Method can be a choice of `"sinkhorn"`, `"greenkhorn"`, `"sinkhorn_stabilized"`, or
74+
`"sinkhorn_epsilon_scaling"` (Flamary et al., 2017).
75+
76+
This function is a wrapper of the function
77+
[`sinkhorn2`](https://pythonot.github.io/all.html?highlight=sinkhorn#ot.sinkhorn2) in the
78+
Python Optimal Transport package.
79+
"""
80+
function sinkhorn2(
81+
mu, nu, C, eps; tol=1e-9, max_iter=1000, method="sinkhorn", verbose=false
82+
)
83+
return pot.sinkhorn2(
84+
nu,
85+
mu,
86+
PyCall.PyReverseDims(C),
87+
eps;
88+
stopThr=tol,
89+
numItermax=max_iter,
90+
method=method,
91+
verbose=verbose,
92+
)[1]
93+
end
94+
95+
"""
96+
sinkhorn_unbalanced(mu, nu, C, eps, lambda; tol = 1e-9, max_iter = 1000, method = "sinkhorn", verbose = false)
97+
98+
Compute optimal transport map of histograms `mu` and `nu` with cost matrix `C`, using
99+
entropic regularisation parameter `eps` and marginal weighting functions `lambda`.
100+
101+
This function is a wrapper of the function
102+
[`sinkhorn_unbalanced`](https://pythonot.github.io/all.html?highlight=sinkhorn_unbalanced#ot.sinkhorn_unbalanced)
103+
in the Python Optimal Transport package.
104+
"""
105+
function sinkhorn_unbalanced(
106+
mu, nu, C, eps, lambda; tol=1e-9, max_iter=1000, method="sinkhorn", verbose=false
107+
)
108+
return pot.sinkhorn_unbalanced(
109+
nu,
110+
mu,
111+
PyCall.PyReverseDims(C),
112+
eps,
113+
lambda;
114+
stopThr=tol,
115+
numItermax=max_iter,
116+
method=method,
117+
verbose=verbose,
118+
)'
119+
end
120+
121+
"""
122+
sinkhorn_unbalanced2(mu, nu, C, eps, lambda; tol = 1e-9, max_iter = 1000, method = "sinkhorn", verbose = false)
123+
124+
Compute optimal transport cost of histograms `mu` and `nu` with cost matrix `C`, using
125+
entropic regularisation parameter `eps` and marginal weighting functions `lambda`.
126+
127+
This function is a wrapper of the function
128+
[`sinkhorn_unbalanced2`](https://pythonot.github.io/all.html#ot.sinkhorn_unbalanced2) in
129+
the Python Optimal Transport package.
130+
"""
131+
function sinkhorn_unbalanced2(
132+
mu, nu, C, eps, lambda; tol=1e-9, max_iter=1000, method="sinkhorn", verbose=false
133+
)
134+
return pot.sinkhorn_unbalanced2(
135+
nu,
136+
mu,
137+
PyCall.PyReverseDims(C),
138+
eps,
139+
lambda;
140+
stopThr=tol,
141+
numItermax=max_iter,
142+
method=method,
143+
verbose=verbose,
144+
)[1]
145+
end

0 commit comments

Comments
 (0)