Skip to content

Commit aa4f370

Browse files
Add optional install options with pip (#627)
* add extra optional dependency options * check for dependencies, more informative errors: da, dr * update RELEASES.md * update README.md * Update setup.py Co-authored-by: Rémi Flamary <[email protected]> * change filename to requirements_all in all files * woops some of those were for docs --------- Co-authored-by: Rémi Flamary <[email protected]>
1 parent 14c08ba commit aa4f370

File tree

9 files changed

+38
-12
lines changed

9 files changed

+38
-12
lines changed

.circleci/config.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
command: |
4747
python -m pip install --user --upgrade --progress-bar off pip
4848
python -m pip install --user -e .
49-
python -m pip install --user --upgrade --no-cache-dir --progress-bar off -r requirements.txt
49+
python -m pip install --user --upgrade --no-cache-dir --progress-bar off -r requirements_all.txt
5050
python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt
5151
python -m pip install --user --upgrade --progress-bar off ipython sphinx-gallery memory_profiler
5252
# python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler

.github/workflows/build_doc.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- name: Get Python running
2525
run: |
2626
python -m pip install --user --upgrade --progress-bar off pip
27-
python -m pip install --user --upgrade --progress-bar off -r requirements.txt
27+
python -m pip install --user --upgrade --progress-bar off -r requirements_all.txt
2828
python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt
2929
python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler
3030
python -m pip install --user -e .

.github/workflows/build_tests.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
- name: Install dependencies
3737
run: |
3838
python -m pip install --upgrade pip
39-
pip install -r requirements.txt
39+
pip install -r requirements_all.txt
4040
pip install pytest pytest-cov
4141
- name: Run tests
4242
run: |
@@ -108,7 +108,7 @@ jobs:
108108
- name: Install dependencies
109109
run: |
110110
python -m pip install --upgrade pip
111-
pip install -r requirements.txt
111+
pip install -r requirements_all.txt
112112
pip install pytest
113113
- name: Run tests
114114
run: |

README.md

+6
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ or get the very latest version by running:
113113
pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root)
114114
```
115115

116+
Optional dependencies may be installed with
117+
```console
118+
pip install POT[all]
119+
```
120+
Note that this installs `cvxopt`, which is licensed under GPL 3.0. Alternatively, if you cannot use GPL-licensed software, the specific optional dependencies may be installed individually, or per-submodule. The available optional installations are `backend-jax, backend-tf, backend-torch, cvxopt, dr, gnn, all`.
121+
116122
#### Anaconda installation with conda-forge
117123

118124
If you use the Anaconda python distribution, POT is available in [conda-forge](https://conda-forge.org). To install it and the required dependencies:

RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
+ New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620)
1111
+ Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605).
1212
+ Added support for [Low rank Gromov-Wasserstein](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf) with `ot.gromov.lowrank_gromov_wasserstein_samples` (PR #614)
13+
+ Optional dependencies may now be installed with `pip install POT[all]` The specific backends or submodules' dependencies may also be installed individually. The pip options are: `backend-jax, backend-tf, backend-torch, cvxopt, dr, gnn, all`. The installation of the `cupy` backend should be done with conda.
1314

1415
#### Closed issues
1516
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)

ot/da.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,10 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
376376
elif sim == 'knn':
377377
if sim_param is None:
378378
sim_param = 3
379-
380-
from sklearn.neighbors import kneighbors_graph
379+
try:
380+
from sklearn.neighbors import kneighbors_graph
381+
except ImportError:
382+
raise ValueError('scikit-learn must be installed to use knn similarity. Install with `$pip install scikit-learn`.')
381383

382384
sS = nx.from_numpy(kneighbors_graph(
383385
X=nx.to_numpy(xs), n_neighbors=int(sim_param)

ot/dr.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
# License: MIT License
1818

1919
from scipy import linalg
20-
import autograd.numpy as np
21-
from sklearn.decomposition import PCA
22-
23-
import pymanopt
24-
import pymanopt.manifolds
25-
import pymanopt.optimizers
20+
try:
21+
import autograd.numpy as np
22+
from sklearn.decomposition import PCA
23+
24+
import pymanopt
25+
import pymanopt.manifolds
26+
import pymanopt.optimizers
27+
except ImportError:
28+
raise ImportError("Missing dependency for ot.dr. Requires autograd, pymanopt, scikit-learn. You can install with install with 'pip install POT[dr]', or 'conda install autograd pymanopt scikit-learn'")
2629

2730
from .bregman import sinkhorn as sinkhorn_bregman
2831
from .utils import dist as dist_utils, check_random_state
File renamed without changes.

setup.py

+14
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@
4646
sdk_path = subprocess.check_output(['xcrun', '--show-sdk-path'])
4747
os.environ['CFLAGS'] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8"))
4848

49+
with open('requirements_all.txt') as f:
50+
optional_requirements = f.read().splitlines()
51+
4952
setup(
5053
name='POT',
5154
version=__version__,
@@ -70,6 +73,17 @@
7073
scripts=[],
7174
data_files=[],
7275
install_requires=["numpy>=1.16", "scipy>=1.6"],
76+
extras_require={
77+
'backend-numpy': [], # in requirements.
78+
'backend-jax': ['jax<=0.4.24', 'jaxlib<=0.4.24'],
79+
'backend-cupy': [], # should be installed with conda, not pip, or figure out what CUDA version above.
80+
'backend-tf': ['tensorflow'],
81+
'backend-torch': ['torch'],
82+
'cvxopt': ['cvxopt'], # on it's own to prevent accidental GPL violations
83+
'dr': ['scikit-learn', 'pymanopt', 'autograd'],
84+
'gnn': ['torch', 'torch_geometric'],
85+
'all': optional_requirements
86+
},
7387
python_requires=">=3.7",
7488
classifiers=[
7589
'Development Status :: 5 - Production/Stable',

0 commit comments

Comments
 (0)