Skip to content

Commit

Permalink
Merge pull request #874 from minrk/broadcast-map
Browse files Browse the repository at this point in the history
add BroadcastView.map
  • Loading branch information
minrk authored Mar 11, 2024
2 parents 9ff1800 + 7aa1d07 commit 142583e
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 28 deletions.
115 changes: 111 additions & 4 deletions ipyparallel/client/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from IPython import get_ipython
from traitlets import Any, Bool, CFloat, Dict, HasTraits, Instance, Integer, List, Set

import ipyparallel as ipp
from ipyparallel import util
from ipyparallel.controller.dependency import Dependency, dependent

Expand Down Expand Up @@ -767,7 +768,7 @@ def scatter(
mapObject = Map.dists[dist]()
nparts = len(targets)
futures = []
trackers = []
_lengths = []
for index, engineid in enumerate(targets):
partition = mapObject.getPartition(seq, index, nparts)
if flatten and len(partition) == 1:
Expand All @@ -777,10 +778,12 @@ def scatter(
r = self.push(ns, block=False, track=track, targets=engineid)
r.owner = False
futures.extend(r._children)
_lengths.append(len(partition))

r = AsyncResult(
self.client, futures, fname='scatter', targets=targets, owner=True
)
r._scatter_lengths = _lengths
if block:
r.wait()
else:
Expand Down Expand Up @@ -930,7 +933,6 @@ def _really_apply(
track = self.track if track is None else track
targets = self.targets if targets is None else targets
idents, _targets = self.client._build_targets(targets)
futures = []

pf = PrePickled(f)
pargs = [PrePickled(arg) for arg in args]
Expand Down Expand Up @@ -1014,8 +1016,113 @@ def make_asyncresult(message_future):
pass
return ar

# Placeholder `map`: accepts the usual map arguments but does no work and
# unconditionally raises NotImplementedError.
# NOTE(review): this page is a commit diff view; this stub appears to be the
# pre-change version that the `map` defined further down replaces - confirm.
def map(self, f, *sequences, **kwargs):
raise NotImplementedError("BroadcastView.map not yet implemented")
@staticmethod
def _broadcast_map(f, *sequence_names):
    """Apply ``f`` over sequences that were scattered ahead of time.

    This is the function that ``map`` submits through ``apply``. The
    sequences themselves are not passed as arguments: ``map`` scatters
    them first, so only the *names* of the resulting variables arrive here.

    Steps:

    - look up each name in the interactive user namespace
    - remove those temporary variables from the namespace as they are read
    - return ``list(map(f, *sequences))``

    Parameters
    ----------
    f : callable
        function to call on the items of the sequences
    *sequence_names : str
        names of the user-namespace variables holding the sequences

    Returns
    -------
    list
        the results of ``map(f, *sequences)``
    """
    shell = get_ipython()
    # pop rather than get: the variables are temporary and should not
    # remain in the user namespace after this call
    local_sequences = [shell.user_ns.pop(name) for name in sequence_names]
    return list(map(f, *local_sequences))

@_not_coalescing
def map(self, f, *sequences, block=None, track=False, return_exceptions=False):
    """Parallel version of builtin `map`, using this View's `targets`.

    There will be one task per engine, so work will be chunked
    if the sequences are longer than `targets`.

    Results can be iterated as they are ready, but will become available in chunks.

    .. note::

        BroadcastView does not yet have a fully native map implementation.
        In particular, the scatter step is still one message per engine,
        identical to DirectView,
        and typically slower due to the more complex scheduler.

        It is more efficient to partition inputs via other means
        (e.g. SPMD based on rank & size)
        and use `apply` to submit all tasks in one broadcast.

    .. versionadded:: 8.8

    Parameters
    ----------
    f : callable
        function to be mapped
    *sequences : one or more sequences of matching length
        the sequences to be distributed and passed to `f`
    block : bool [default self.block]
        whether to wait for the result or not
    track : bool [default False]
        Track underlying zmq send to indicate when it is safe to modify memory.
        Only for zero-copy sends such as numpy arrays that are going to be
        modified in-place. Passing ``None`` explicitly uses ``self.track``.
    return_exceptions : bool [default False]
        Return remote Exceptions in the result sequence instead of raising them.

    Returns
    -------
    If block=False
        An :class:`~ipyparallel.client.asyncresult.AsyncMapResult` instance.
        An object like AsyncResult, but which reassembles the sequence of results
        into a single list. AsyncMapResults can be iterated through before all
        results are complete.
    else
        A list, the result of ``map(f, *sequences)``

    Raises
    ------
    TypeError
        If no sequences are given.
    """
    if not sequences:
        # Without at least one sequence nothing is scattered, so there are
        # no chunk sizes to reassemble results with. Fail before sending
        # any message instead of broadcasting a task that cannot succeed.
        raise TypeError("map() requires at least one sequence")

    if block is None:
        block = self.block
    if track is None:
        track = self.track

    # unique identifier, since we're living in the interactive namespace
    map_key = secrets.token_hex(5)
    dist = 'b'
    map_object = Map.dists[dist]()

    seq_names = []
    for i, seq in enumerate(sequences):
        seq_name = f"_seq_{map_key}_{i}"
        seq_names.append(seq_name)
        try:
            len(seq)
        except Exception:
            # cast length-less sequences (e.g. Range) to list
            seq = list(seq)

        ar = self.scatter(seq_name, seq, dist=dist, block=False, track=track)
        # sequences are documented to have matching length, so the chunk
        # sizes of the last scatter are used for the whole map
        scatter_chunk_sizes = ar._scatter_lengths

    # submit the map tasks as an actual broadcast
    ar = self.apply(self._broadcast_map, f, *seq_names)
    # the futures are handed to the AsyncMapResult below
    ar.owner = False
    # re-wrap messages in an AsyncMapResult to get map API
    # this is where the 'gather' reconstruction happens
    amr = ipp.AsyncMapResult(
        self.client,
        ar._children,
        map_object,
        fname=getname(f),
        return_exceptions=return_exceptions,
        chunk_sizes={
            future.msg_id: chunk_size
            for future, chunk_size in zip(ar._children, scatter_chunk_sizes)
        },
    )

    if block:
        return amr.get()
    return amr

# scatter/gather cannot be coalescing yet
scatter = _not_coalescing(DirectView.scatter)
Expand Down
24 changes: 0 additions & 24 deletions ipyparallel/tests/test_view_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,6 @@ def teardown_method(self):
if not self._broadcast_view_used:
pytest.skip("No broadcast view used")

# The six test stubs below all have an empty (`pass`) body, so they perform
# no assertions of their own.
# NOTE(review): `needs_map` is defined outside this view - presumably it marks
# tests that depend on BroadcastView.map being implemented; confirm.
@needs_map
def test_map(self):
pass

@needs_map
def test_map_ref(self):
pass

@needs_map
def test_map_reference(self):
pass

@needs_map
def test_map_iterable(self):
pass

@needs_map
def test_map_empty_sequence(self):
pass

@needs_map
def test_map_numpy(self):
pass

# Empty (`pass`) test body, marked as an expected failure; the reason string
# in the decorator records why.
@pytest.mark.xfail(reason="Tracking gets disconnected from original message")
def test_scatter_tracked(self):
pass
Expand Down

0 comments on commit 142583e

Please sign in to comment.