Skip to content

Commit 8f24cb9

Browse files
committed
add check_number_threads to ot/lp/__init__.py __all__
1 parent f268515 commit 8f24cb9

File tree

3 files changed

+27
-25
lines changed

3 files changed

+27
-25
lines changed

ot/lp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
free_support_barycenter,
1717
generalized_free_support_barycenter,
1818
)
19+
from ..utils import check_number_threads
1920

2021
# import compiled emd
2122
from .emd_wrap import emd_1d_sorted
@@ -44,4 +45,5 @@
4445
"semidiscrete_wasserstein2_unif_circle",
4546
"dmmot_monge_1dgrid_loss",
4647
"dmmot_monge_1dgrid_optimize",
48+
"check_number_threads",
4749
]

ot/lp/network_simplex.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,35 +11,11 @@
1111
import numpy as np
1212
import warnings
1313

14-
from ..utils import list_to_array
14+
from ..utils import list_to_array, check_number_threads
1515
from ..backend import get_backend
1616
from .emd_wrap import emd_c, check_result
1717

1818

19-
def check_number_threads(numThreads):
20-
"""Checks whether or not the requested number of threads has a valid value.
21-
22-
Parameters
23-
----------
24-
numThreads : int or str
25-
The requested number of threads, should either be a strictly positive integer or "max" or None
26-
27-
Returns
28-
-------
29-
numThreads : int
30-
Corrected number of threads
31-
"""
32-
if (numThreads is None) or (
33-
isinstance(numThreads, str) and numThreads.lower() == "max"
34-
):
35-
return -1
36-
if (not isinstance(numThreads, int)) or numThreads < 1:
37-
raise ValueError(
38-
'numThreads should either be "max" or a strictly positive integer'
39-
)
40-
return numThreads
41-
42-
4319
def center_ot_dual(alpha0, beta0, a=None, b=None):
4420
r"""Center dual OT potentials w.r.t. their weights
4521

ot/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,3 +1341,27 @@ def proj_SDP(S, nx=None, vmin=0.0):
13411341
Q = nx.einsum("ijk,ik->ijk", P, w) # Q[i] = P[i] @ diag(w[i])
13421342
# R[i] = Q[i] @ P[i].T
13431343
return nx.einsum("ijk,ikl->ijl", Q, nx.transpose(P, (0, 2, 1)))
1344+
1345+
1346+
def check_number_threads(numThreads):
1347+
"""Checks whether or not the requested number of threads has a valid value.
1348+
1349+
Parameters
1350+
----------
1351+
numThreads : int or str
1352+
The requested number of threads, should either be a strictly positive integer or "max" or None
1353+
1354+
Returns
1355+
-------
1356+
numThreads : int
1357+
Corrected number of threads
1358+
"""
1359+
if (numThreads is None) or (
1360+
isinstance(numThreads, str) and numThreads.lower() == "max"
1361+
):
1362+
return -1
1363+
if (not isinstance(numThreads, int)) or numThreads < 1:
1364+
raise ValueError(
1365+
'numThreads should either be "max" or a strictly positive integer'
1366+
)
1367+
return numThreads

0 commit comments

Comments
 (0)