Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pause receiving while submitting tasks #534

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion ipyparallel/client/asyncresult.py
Original file line number Diff line number Diff line change
@@ -84,7 +84,7 @@ def __init__(

self._return_exceptions = return_exceptions

if isinstance(children[0], string_types):
if children and isinstance(children[0], string_types):
self.msg_ids = children
self._children = []
else:
@@ -96,6 +96,15 @@ def __init__(
self._targets = targets
self.owner = owner

if not children:
# empty result!
self._ready = True
self._success = True
f = Future()
f.set_result([])
self._resolve_result(f)
return

self._ready = False
self._ready_event = Event()
self._output_ready = False
83 changes: 73 additions & 10 deletions ipyparallel/client/client.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
import warnings
from collections.abc import Iterable
from concurrent.futures import Future
from contextlib import contextmanager
from getpass import getpass
from pprint import pprint
from threading import current_thread
@@ -990,21 +991,59 @@ def _stop_io_thread(self):
self._io_thread.join()

def _setup_streams(self):
self._query_stream = ZMQStream(self._query_socket, self._io_loop)
self._query_stream.on_recv(self._dispatch_single_reply, copy=False)
self._control_stream = ZMQStream(self._control_socket, self._io_loop)
self._streams = [] # all streams
self._engine_streams = [] # streams that talk to engines
self._query_stream = s = ZMQStream(self._query_socket, self._io_loop)
self._streams.append(s)
self._notification_stream = s = ZMQStream(
self._notification_socket, self._io_loop
)
self._streams.append(s)

self._control_stream = s = ZMQStream(self._control_socket, self._io_loop)
self._streams.append(s)
self._engine_streams.append(s)
self._mux_stream = s = ZMQStream(self._mux_socket, self._io_loop)
self._streams.append(s)
self._engine_streams.append(s)
self._task_stream = s = ZMQStream(self._task_socket, self._io_loop)
self._streams.append(s)
self._engine_streams.append(s)
self._broadcast_stream = s = ZMQStream(self._broadcast_socket, self._io_loop)
self._streams.append(s)
self._engine_streams.append(s)
self._iopub_stream = s = ZMQStream(self._iopub_socket, self._io_loop)
self._streams.append(s)
self._engine_streams.append(s)
self._start_receiving(all=True)

def _start_receiving(self, all=False):
"""Start receiving on streams
default: only engine streams
if all: include hub streams
"""
if all:
self._query_stream.on_recv(self._dispatch_single_reply, copy=False)
self._notification_stream.on_recv(self._dispatch_notification, copy=False)
self._control_stream.on_recv(self._dispatch_single_reply, copy=False)
self._mux_stream = ZMQStream(self._mux_socket, self._io_loop)
self._mux_stream.on_recv(self._dispatch_reply, copy=False)
self._task_stream = ZMQStream(self._task_socket, self._io_loop)
self._task_stream.on_recv(self._dispatch_reply, copy=False)
self._iopub_stream = ZMQStream(self._iopub_socket, self._io_loop)
self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
self._iopub_stream.on_recv(self._dispatch_iopub, copy=False)
self._notification_stream = ZMQStream(self._notification_socket, self._io_loop)
self._notification_stream.on_recv(self._dispatch_notification, copy=False)

self._broadcast_stream = ZMQStream(self._broadcast_socket, self._io_loop)
self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
def _stop_receiving(self, all=False):
"""Stop receiving on engine streams
If all: include hub streams
"""
if all:
streams = self._streams
else:
streams = self._engine_streams
for s in streams:
s.stop_on_recv()

def _start_io_thread(self):
"""Start IOLoop in a background thread."""
@@ -1034,6 +1073,30 @@ def _io_main(self, start_evt=None):
self._io_loop.start()
self._io_loop.close()

@contextmanager
def _pause_results(self):
"""Context manager to pause receiving results
When submitting lots of tasks,
the arrival of results can disrupt the processing
of new submissions.
Threadsafe.
"""
f = Future()

def _stop():
self._stop_receiving()
f.set_result(None)

# use add_callback to make it threadsafe
self._io_loop.add_callback(_stop)
f.result()
try:
yield
finally:
self._io_loop.add_callback(self._start_receiving)

@unpack_message
def _dispatch_single_reply(self, msg):
"""Dispatch single (non-execution) replies"""
4 changes: 4 additions & 0 deletions ipyparallel/client/map.py
Original file line number Diff line number Diff line change
@@ -66,6 +66,8 @@ def joinPartitions(self, listOfPartitions):
return self.concatenate(listOfPartitions)

def concatenate(self, listOfPartitions):
if len(listOfPartitions) == 0:
return listOfPartitions
testObject = listOfPartitions[0]
# First see if we have a known array type
if is_array(testObject):
@@ -88,6 +90,8 @@ def getPartition(self, seq, p, q, n=None):
return seq[p:n:q]

def joinPartitions(self, listOfPartitions):
if len(listOfPartitions) == 0:
return listOfPartitions
testObject = listOfPartitions[0]
# First see if we have a known array type
if is_array(testObject):
12 changes: 11 additions & 1 deletion ipyparallel/client/remotefunction.py
Original file line number Diff line number Diff line change
@@ -244,7 +244,17 @@ def __call__(self, *sequences, **kwargs):

if maxlen == 0:
# nothing to iterate over
return []
if self.block:
return []
else:
return AsyncMapResult(
self.view.client,
[],
self.mapObject,
fname=getname(self.func),
ordered=self.ordered,
return_exceptions=self.return_exceptions,
)

# check that the length of sequences match
if not _mapping and minlen != maxlen:
44 changes: 30 additions & 14 deletions ipyparallel/client/view.py
Original file line number Diff line number Diff line change
@@ -578,11 +578,12 @@ def _really_apply(
pargs = [PrePickled(arg) for arg in args]
pkwargs = {k: PrePickled(v) for k, v in kwargs.items()}

for ident in _idents:
future = self.client.send_apply_request(
self._socket, pf, pargs, pkwargs, track=track, ident=ident
)
futures.append(future)
with self.client._pause_results():
for ident in _idents:
future = self.client.send_apply_request(
self._socket, pf, pargs, pkwargs, track=track, ident=ident
)
futures.append(future)
if track:
trackers = [_.tracker for _ in futures]
else:
@@ -641,9 +642,16 @@ def map(self, f, *sequences, block=None, track=False, return_exceptions=False):

assert len(sequences) > 0, "must have some sequences to map onto!"
pf = ParallelFunction(
self, f, block=block, track=track, return_exceptions=return_exceptions
self, f, block=False, track=track, return_exceptions=return_exceptions
)
return pf.map(*sequences)
with self.client._pause_results():
ar = pf.map(*sequences)
if block:
try:
return ar.get()
except KeyboardInterrupt:
return ar
return ar

@sync_results
@save_ids
@@ -665,11 +673,12 @@ def execute(self, code, silent=True, targets=None, block=None):

_idents, _targets = self.client._build_targets(targets)
futures = []
for ident in _idents:
future = self.client.send_execute_request(
self._socket, code, silent=silent, ident=ident
)
futures.append(future)
with self.client._pause_results():
for ident in _idents:
future = self.client.send_execute_request(
self._socket, code, silent=silent, ident=ident
)
futures.append(future)
if isinstance(targets, int):
futures = futures[0]
ar = AsyncResult(
@@ -1292,12 +1301,19 @@ def map(
pf = ParallelFunction(
self,
f,
block=block,
block=False,
chunksize=chunksize,
ordered=ordered,
return_exceptions=return_exceptions,
)
return pf.map(*sequences)
with self.client._pause_results():
ar = pf.map(*sequences)
if block:
try:
return ar.get()
except KeyboardInterrupt:
return ar
return ar

def imap(
self,