Skip to content

Commit 7ac1b46

Browse files
committed
cleanup parmap on windows
1 parent 0bc936f commit 7ac1b46

File tree

4 files changed

+57
-40
lines changed

4 files changed

+57
-40
lines changed

.travis.yml

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,32 @@
11
dist: xenial # required for Python >= 3.7
22
language: python
33
matrix:
4-
# allow_failures:
5-
# - os: osx
6-
include:
7-
# - os: osx
8-
# language: generic
9-
- os: linux
10-
sudo: required
11-
python: 3.4
12-
- os: linux
13-
sudo: required
14-
python: 3.5
15-
- os: linux
16-
sudo: required
17-
python: 3.6
18-
- os: linux
19-
sudo: required
20-
python: 3.7
21-
- os: linux
22-
sudo: required
23-
python: 2.7
24-
- name: "Python 3.7.3 on Windows"
25-
os: windows # Windows 10.0.17134 N/A Build 17134
26-
language: shell # 'language: python' is an error on Travis CI Windows
27-
before_install: choco install python
28-
env: PATH=/c/Python37:/c/Python37/Scripts:$PATH
4+
allow_failures:
5+
- os: osx
6+
- os: windows
7+
include:
8+
- os: osx
9+
language: generic
10+
- os: linux
11+
sudo: required
12+
python: 3.4
13+
- os: linux
14+
sudo: required
15+
python: 3.5
16+
- os: linux
17+
sudo: required
18+
python: 3.6
19+
- os: linux
20+
sudo: required
21+
python: 3.7
22+
- os: linux
23+
sudo: required
24+
python: 2.7
25+
- name: "Python 3.7.3 on Windows"
26+
os: windows # Windows 10.0.17134 N/A Build 17134
27+
language: shell # 'language: python' is an error on Travis CI Windows
28+
before_install: choco install python
29+
env: PATH=/c/Python37:/c/Python37/Scripts:$PATH
2930
before_install:
3031
- ./.travis/before_install.sh
3132
# command to install dependencies

ot/lp/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# License: MIT License
1212

1313
import multiprocessing
14-
14+
import sys
1515
import numpy as np
1616
from scipy.sparse import coo_matrix
1717

@@ -151,6 +151,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
151151
Target histogram (uniform weight if empty list)
152152
M : (ns,nt) numpy.ndarray, float64
153153
Loss matrix (c-order array with type float64)
154+
processes : int, optional (default=nb cpu)
155+
Nb of processes used for multiple emd computation (not used on windows)
154156
numItermax : int, optional (default=100000)
155157
The maximum number of iterations before stopping the optimization
156158
algorithm if it has not converged.
@@ -200,6 +202,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
200202
b = np.asarray(b, dtype=np.float64)
201203
M = np.asarray(M, dtype=np.float64)
202204

205+
# problem with pikling Forks
206+
if sys.platform.endswith('win32'):
207+
processes=1
208+
203209
# if empty array given then use uniform distributions
204210
if len(a) == 0:
205211
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -228,7 +234,11 @@ def f(b):
228234
return f(b)
229235
nb = b.shape[1]
230236

231-
res = parmap(f, [b[:, i] for i in range(nb)], processes)
237+
if processes>1:
238+
res = parmap(f, [b[:, i] for i in range(nb)], processes)
239+
else:
240+
res = list(map(f, [b[:, i].copy() for i in range(nb)]))
241+
232242
return res
233243

234244

ot/utils.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -214,23 +214,28 @@ def fun(f, q_in, q_out):
214214

215215

216216
def parmap(f, X, nprocs=multiprocessing.cpu_count()):
217-
""" paralell map for multiprocessing """
218-
q_in = multiprocessing.Queue(1)
219-
q_out = multiprocessing.Queue()
217+
""" paralell map for multiprocessing (only map on windows)"""
220218

221-
proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
222-
for _ in range(nprocs)]
223-
for p in proc:
224-
p.daemon = True
225-
p.start()
219+
if not sys.platform.endswith('win32'):
226220

227-
sent = [q_in.put((i, x)) for i, x in enumerate(X)]
228-
[q_in.put((None, None)) for _ in range(nprocs)]
229-
res = [q_out.get() for _ in range(len(sent))]
221+
q_in = multiprocessing.Queue(1)
222+
q_out = multiprocessing.Queue()
230223

231-
[p.join() for p in proc]
224+
proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
225+
for _ in range(nprocs)]
226+
for p in proc:
227+
p.daemon = True
228+
p.start()
232229

233-
return [x for i, x in sorted(res)]
230+
sent = [q_in.put((i, x)) for i, x in enumerate(X)]
231+
[q_in.put((None, None)) for _ in range(nprocs)]
232+
res = [q_out.get() for _ in range(len(sent))]
233+
234+
[p.join() for p in proc]
235+
236+
return [x for i, x in sorted(res)]
237+
else:
238+
return list(map(f, X))
234239

235240

236241
def check_params(**kwargs):

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ matplotlib
55
sphinx-gallery
66
autograd
77
pymanopt
8+
cvxopt
89
pytest

0 commit comments

Comments
 (0)