Commit f8d871e

[MRG] Tensorflow backend & Benchmarker & Myst_parser (#316)
* First batch of tf methods (to be continued)
* Second batch of method (yet to debug)
* tensorflow for cpu
* add tf requirement
* pep8 + bug
* small changes
* attempt to solve pymanopt bug with tf2
* attempt #2
* attempt #3
* attempt 4
* docstring
* correct pep8 violation introduced in merge conflicts resolution
* attempt 5
* attempt 6
* just a random try
* Revert "just a random try" This reverts commit 8223e76.
* GPU tests for tensorflow
* pep8
* attempt to solve issue with m2r2
* Remove transpose backend method
* first draft of benchmarker (need to correct time measurement)
* prettier bench table
* Bitsize and prettier device methods
* prettified table bench
* Bug corrected (results were mixed up in the final table)
* Better perf counter (for GPU support)
* pep8
* EMD bench
* solve bug if no GPU available
* pep8
* warning about tensorflow numpy api being required in the backend.py docstring
* Bug solve in backend docstring
* not covering code which requires a GPU
* Tensorflow gradients manipulation tested
* Number of warmup runs is now customizable
* typo
* Remove some warnings while building docs
* Change prettier_device to device_type in backend
* Correct JAX mistakes preventing to see the CPU if a GPU is present
* Attempt to solve JAX bug in case no GPU is found
* Reworked benchmarks order and results storage & clear GPU after usage by benchmark
* Add bench to backend docstring
* better benchs
* remove useless stuff
* Better device_type
* Now using MYST_PARSER and solving links issue in the README.md / online docs
1 parent b3dc68f commit f8d871e

30 files changed, +1161 -97 lines changed

.github/requirements_test_windows.txt

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ cython
 matplotlib
 autograd
 pymanopt==0.2.4; python_version <'3'
-pymanopt; python_version >= '3'
+pymanopt==0.2.6rc1; python_version >= '3'
 cvxopt
 scikit-learn
 pytest

README.md

Lines changed: 4 additions & 4 deletions
@@ -35,7 +35,7 @@ POT provides the following generic OT solvers (links to examples):
 * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
   formulations).
 * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
-* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays.
+* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
 
 POT provides the following Machine Learning related solvers:
 
@@ -202,12 +202,12 @@ This toolbox benefit a lot from open source research and we would like to thank
 
 * [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
 * [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT)
-* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) ( C++ code for EMD)
+* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD)
 * [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)
 
 ## Contributions and code of conduct
 
-Every contribution is welcome and should respect the [contribution guidelines](https://pythonot.github.io/contributing.html). Each member of the project is expected to follow the [code of conduct](https://pythonot.github.io/code_of_conduct.html).
+Every contribution is welcome and should respect the [contribution guidelines](.github/CONTRIBUTING.md). Each member of the project is expected to follow the [code of conduct](.github/CODE_OF_CONDUCT.md).
 
 ## Support
 
@@ -217,7 +217,7 @@ You can ask questions and join the development discussion:
 * On the POT [gitter channel](https://gitter.im/PythonOT/community)
 * On the POT [mailing list](https://mail.python.org/mm3/mailman3/lists/pot.python.org/)
 
-You can also post bug reports and feature requests in Github issues. Make sure to read our [guidelines](https://pythonot.github.io/contributing.html) first.
+You can also post bug reports and feature requests in Github issues. Make sure to read our [guidelines](.github/CONTRIBUTING.md) first.
 
 ## References
 

benchmarks/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from . import benchmark
+from . import sinkhorn_knopp
+from . import emd
+
+__all__ = ["benchmark", "sinkhorn_knopp", "emd"]

benchmarks/benchmark.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from ot.backend import get_backend_list, jax, tf
+import gc
+
+
+def setup_backends():
+    if jax:
+        from jax.config import config
+        config.update("jax_enable_x64", True)
+
+    if tf:
+        from tensorflow.python.ops.numpy_ops import np_config
+        np_config.enable_numpy_behavior()
+
+
+def exec_bench(setup, tested_function, param_list, n_runs, warmup_runs):
+    backend_list = get_backend_list()
+    for i, nx in enumerate(backend_list):
+        if nx.__name__ == "tf" and i < len(backend_list) - 1:
+            # Tensorflow should be the last one to be benchmarked because
+            # as far as I'm aware, there is no way to force it to release
+            # GPU memory. Hence, if any other backend is benchmarked after
+            # Tensorflow and requires the usage of a GPU, it will not have the
+            # full memory available and you may have a GPU Out Of Memory error
+            # even though your GPU can technically hold your tensors in memory.
+            backend_list.pop(i)
+            backend_list.append(nx)
+            break
+
+    inputs = [setup(param) for param in param_list]
+    results = dict()
+    for nx in backend_list:
+        for i in range(len(param_list)):
+            print(nx, param_list[i])
+            args = inputs[i]
+            results_nx = nx._bench(
+                tested_function,
+                *args,
+                n_runs=n_runs,
+                warmup_runs=warmup_runs
+            )
+            gc.collect()
+            results_nx_with_param_in_key = dict()
+            for key in results_nx:
+                new_key = (param_list[i], *key)
+                results_nx_with_param_in_key[new_key] = results_nx[key]
+            results.update(results_nx_with_param_in_key)
+    return results
+
+
+def convert_to_html_table(results, param_name, main_title=None, comments=None):
+    string = "<table>\n"
+    keys = list(results.keys())
+    params, names, devices, bitsizes = zip(*keys)
+
+    devices_names = sorted(list(set(zip(devices, names))))
+    params = sorted(list(set(params)))
+    bitsizes = sorted(list(set(bitsizes)))
+    length = len(devices_names) + 1
+    cpus_cols = list(devices).count("CPU") // len(bitsizes) // len(params)
+    gpus_cols = list(devices).count("GPU") // len(bitsizes) // len(params)
+    assert cpus_cols + gpus_cols == len(devices_names)
+
+    if main_title is not None:
+        string += f'<tr><th align="center" colspan="{length}">{str(main_title)}</th></tr>\n'
+
+    for i, bitsize in enumerate(bitsizes):
+
+        if i != 0:
+            string += f'<tr><td colspan="{length}">&nbsp;</td></tr>\n'
+
+        # make bitsize header
+        text = f"{bitsize} bits"
+        if comments is not None:
+            text += " - "
+            if isinstance(comments, (tuple, list)) and len(comments) == len(bitsizes):
+                text += str(comments[i])
+            else:
+                text += str(comments)
+        string += f'<tr><th align="center">Bitsize</th>'
+        string += f'<th align="center" colspan="{length - 1}">{text}</th></tr>\n'
+
+        # make device header
+        string += f'<tr><th align="center">Device</th>'
+        string += f'<th align="center" colspan="{cpus_cols}">CPU</th>'
+        string += f'<th align="center" colspan="{gpus_cols}">GPU</th></tr>\n'
+
+        # make param_name / backend header
+        string += f'<tr><th align="center">{param_name}</th>'
+        for device, name in devices_names:
+            string += f'<th align="center">{name}</th>'
+        string += "</tr>\n"
+
+        # make results rows
+        for param in params:
+            string += f'<tr><td align="center">{param}</td>'
+            for device, name in devices_names:
+                key = (param, name, device, bitsize)
+                string += f'<td align="center">{results[key]:.4f}</td>'
+            string += "</tr>\n"
+
+    string += "</table>"
+    return string
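
The comment in `exec_bench` explains why the Tensorflow backend is moved to the end of the run order. A minimal sketch of that reordering, written without mutating the list during iteration, could look like the following. The helper name `move_backend_last` and the `SimpleNamespace` stand-ins are hypothetical; the only thing assumed about real backend objects is the `__name__` attribute the diff already relies on.

```python
from types import SimpleNamespace


def move_backend_last(backend_list, name="tf"):
    """Return a new list with the backend called `name` moved to the end."""
    others = [nx for nx in backend_list if nx.__name__ != name]
    last = [nx for nx in backend_list if nx.__name__ == name]
    return others + last


# Hypothetical stand-ins for backend objects; only `__name__` is used here.
backends = [SimpleNamespace(__name__=n) for n in ("numpy", "tf", "torch", "jax")]
ordered = [nx.__name__ for nx in move_backend_last(backends)]
print(ordered)  # ['numpy', 'torch', 'jax', 'tf']
```

Building a new list keeps the relative order of the other backends and leaves the input untouched, which may be easier to reason about than the `pop`/`append`/`break` sequence.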

benchmarks/emd.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import ot
+from .benchmark import (
+    setup_backends,
+    exec_bench,
+    convert_to_html_table
+)
+
+
+def setup(n_samples):
+    rng = np.random.RandomState(789465132)
+    x = rng.randn(n_samples, 2)
+    y = rng.randn(n_samples, 2)
+
+    a = ot.utils.unif(n_samples)
+    M = ot.dist(x, y)
+    return a, M
+
+
+if __name__ == "__main__":
+    n_runs = 100
+    warmup_runs = 10
+    param_list = [50, 100, 500, 1000, 2000, 5000]
+
+    setup_backends()
+    results = exec_bench(
+        setup=setup,
+        tested_function=lambda a, M: ot.emd(a, a, M),
+        param_list=param_list,
+        n_runs=n_runs,
+        warmup_runs=warmup_runs
+    )
+    print(convert_to_html_table(
+        results,
+        param_name="Sample size",
+        main_title=f"EMD - Averaged on {n_runs} runs"
+    ))
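
The `results` dictionary that this script hands to `convert_to_html_table` is keyed by `(param, name, device, bitsize)` tuples, judging from how `exec_bench` prepends the parameter to each key and how the table code unpacks them. The sketch below walks through that key handling on made-up timings (the numbers are illustrative, not measurements), and assumes the per-backend keys are `(name, device, bitsize)`.

```python
# Hypothetical timings for one backend, keyed by (name, device, bitsize).
# These numbers are made up for illustration, not measured results.
results_nx = {("numpy", "CPU", 32): 0.25, ("numpy", "CPU", 64): 0.5}

# Prepend the benchmark parameter to every key, as exec_bench does.
results = {}
for param in (50, 100):
    for key, seconds in results_nx.items():
        results[(param, *key)] = seconds

# Unpack the keys column-wise, as convert_to_html_table does.
params, names, devices, bitsizes = zip(*results)
devices_names = sorted(set(zip(devices, names)))
unique_params = sorted(set(params))
unique_bitsizes = sorted(set(bitsizes))

# Integer division keeps the column count usable as an HTML colspan value.
cpu_cols = devices.count("CPU") // len(unique_bitsizes) // len(unique_params)
print(devices_names, unique_params, unique_bitsizes, cpu_cols)
```

With one CPU backend, two bit sizes and two parameters, the four keys collapse to a single CPU column, which is what the device header row needs.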

benchmarks/sinkhorn_knopp.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import ot
+from .benchmark import (
+    setup_backends,
+    exec_bench,
+    convert_to_html_table
+)
+
+
+def setup(n_samples):
+    rng = np.random.RandomState(123456789)
+    a = rng.rand(n_samples // 4, 100)
+    b = rng.rand(n_samples, 100)
+
+    wa = ot.unif(n_samples // 4)
+    wb = ot.unif(n_samples)
+
+    M = ot.dist(a.copy(), b.copy())
+    return wa, wb, M
+
+
+if __name__ == "__main__":
+    n_runs = 100
+    warmup_runs = 10
+    param_list = [50, 100, 500, 1000, 2000, 5000]
+
+    setup_backends()
+    results = exec_bench(
+        setup=setup,
+        tested_function=lambda *args: ot.bregman.sinkhorn(*args, reg=1, stopThr=1e-7),
+        param_list=param_list,
+        n_runs=n_runs,
+        warmup_runs=warmup_runs
+    )
+    print(convert_to_html_table(
+        results,
+        param_name="Sample size",
+        main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs"
+    ))
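
Both benchmark scripts pass `n_runs` and `warmup_runs` down to the backend's `_bench` method, whose body is not part of the files shown here. As a rough, CPU-only illustration of what warmup-then-average timing usually means, here is a sketch using only the standard library. The function `bench` is hypothetical and is not the POT implementation; a GPU backend would additionally need to synchronize the device before reading the clock.

```python
import time


def bench(func, *args, n_runs=100, warmup_runs=10):
    """Average wall-clock time of `func(*args)` over `n_runs` timed calls."""
    for _ in range(warmup_runs):
        func(*args)  # untimed: absorbs one-off costs such as JIT compilation
    start = time.perf_counter()
    for _ in range(n_runs):
        func(*args)
    return (time.perf_counter() - start) / n_runs


# Record every call so the number of warmup and timed runs can be checked.
calls = []
avg = bench(lambda x: calls.append(x), 1, n_runs=5, warmup_runs=2)
print(len(calls))  # 7 calls in total: 2 warmup + 5 timed
```

Excluding the warmup calls from the average matters most for backends that compile or allocate lazily on first use.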

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@ numpydoc
 memory_profiler
 pillow
 networkx
-m2r2
+myst-parser

docs/requirements_rtd.txt

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ numpydoc
 memory_profiler
 pillow
 networkx
-m2r2
+myst-parser
 numpy
 scipy>=1.0
 cython

docs/source/.github/CODE_OF_CONDUCT.rst

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+Code of Conduct
+===============
+
+.. include:: ../../../.github/CODE_OF_CONDUCT.md
+   :parser: myst_parser.sphinx_
+   :start-line: 2

docs/source/.github/CONTRIBUTING.rst

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+Contributing to POT
+===================
+
+.. include:: ../../../.github/CONTRIBUTING.md
+   :parser: myst_parser.sphinx_
+   :start-line: 3

docs/source/code_of_conduct.rst

Lines changed: 0 additions & 1 deletion
This file was deleted.

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def __getattr__(cls, name):
     'sphinx.ext.viewcode',
     'sphinx.ext.napoleon',
     'sphinx_gallery.gen_gallery',
-    'm2r2'
+    'myst_parser'
 ]
 
 autosummary_generate = True

docs/source/contributing.rst

Lines changed: 0 additions & 1 deletion
This file was deleted.

docs/source/index.rst

Lines changed: 4 additions & 5 deletions
@@ -17,12 +17,11 @@ Contents
     all
     auto_examples/index
     releases
-    contributing
-    Code of Conduct <code_of_conduct>
-
-.. mdinclude:: ../../README.md
-    :start-line: 2
+    .github/CONTRIBUTING
+    .github/CODE_OF_CONDUCT
 
+.. include:: ../../README.md
+    :parser: myst_parser.sphinx_
 
 
 Indices and tables
