Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update dask vine executor to new dask graphs #4015

Merged
merged 12 commits into from
Jan 17, 2025
3 changes: 1 addition & 2 deletions taskvine/src/bindings/python3/ndcctools/taskvine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,11 @@
LibraryTask,
FunctionCall,
)
from .dask_dag import DaskVineDag

from . import cvine

try:
from .dask_executor import DaskVine
from .dask_dag import DaskVineDag
except ImportError as e:
print(f"DaskVine not available. Couldn't find module: {e.name}")

Expand Down
216 changes: 93 additions & 123 deletions taskvine/src/bindings/python3/ndcctools/taskvine/dask_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,213 +2,183 @@
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.

from uuid import uuid4
from collections import defaultdict
import dask._task_spec as dts


class DaskVineDag:
"""A directed graph that encodes the steps and state a computation needs.
Single computations are encoded as s-expressions, therefore it is 'upside-down',
in the sense that the children of a node are the nodes required to compute it.
E.g., for
Single computations are encoded as dts.Task's, with dependecies expressed as the keys needed by the task.

dsk = {'x': 1,
'y': 2,
'z': (add, 'x', 'y'),
'w': (sum, ['x', 'y', 'z']),
'v': [(sum, ['w', 'z']), 2]
'z': dts.Task('z', add, dts.TaskRef('x'), dts.TaskRef('y'))
'w': dts.Task('w', sum, [dts.TaskRef('x'), dts.TaskRef('y'), dts.TaskRef('z')]),
'v': dts.Task('v', sum, [dts.TaskRef('w'), dts.TaskRef('z')])
't': dts.Task('v', sum, [dts.TaskRef('v'), 2])
}

'z' has as children 'x' and 'y'.

Each node is referenced by its key. When the value of a key is list of
sexprs, like 'v' above, and low_memory_mode is True, then a key is automatically computed recursively
for each computation.
'z' has as dependecies 'x' and 'y'.

Computation is done lazily. The DaskVineDag is initialized from a task graph, but not
computation is decoded. To use the DaskVineDag:
- DaskVineDag.set_targets(keys): Request the computation associated with key to be decoded.
- DaskVineDag.get_ready(): A list of [key, sexpr] of expressions that are ready
to be executed.
- DaskVineDag.get_ready(): A list of dts.Task that are ready to be executed.
- DaskVineDag.set_result(key, value): Sets the result of key to value.
- DaskVineDag.get_result(key): Get result associated with key. Raises DagNoResult
- DaskVineDag.has_result(key): Whether the key has a computed result. """

@staticmethod
def hashable(s):
try:
hash(s)
return True
except TypeError:
return False

@staticmethod
def keyp(s):
return DaskVineDag.hashable(s) and not DaskVineDag.taskp(s)
return DaskVineDag.hashable(s) and not DaskVineDag.taskref(s) and not DaskVineDag.taskp(s)

@staticmethod
def taskp(s):
return isinstance(s, tuple) and len(s) > 0 and callable(s[0])
def taskref(s):
return isinstance(s, (dts.TaskRef, dts.Alias))

@staticmethod
def listp(s):
return isinstance(s, list)
def taskp(s):
return isinstance(s, dts.Task)

@staticmethod
def symbolp(s):
return not (DaskVineDag.taskp(s) or DaskVineDag.listp(s))
def containerp(s):
return isinstance(s, dts.NestedContainer)

@staticmethod
def hashable(s):
try:
hash(s)
return True
except TypeError:
return False
def symbolp(s):
return isinstance(s, dts.DataNode)

def __init__(self, dsk, low_memory_mode=False):
def __init__(self, dsk):
self._dsk = dsk

# child -> parents. I.e., which parents needs the result of child
self._parents_of = defaultdict(lambda: set())
# For a key, the set of keys that need it to perform a computation.
self._needed_by = defaultdict(lambda: set())

# parent->children still waiting for result. A key is ready to be computed when children left is []
self._missing_of = {}
# For a key, the subset of self._needed_by[key] that still need to be completed.
# Only useful for gc.
self._pending_needed_by = defaultdict(lambda: set())

# parent->nchildren get the number of children for parent computation
self._children_of = {}
# For a key, the set of keys that it needs for computation.
self._dependencies_of = {}

# For a key, the set of keys with a pending result for they key to be computed.
# When the set is empty, the key is ready to be computed. It is always a subset
# of self._dependencies_of[key].
self._missing_of = {}

# key->value of its computation
self._result_of = {}

# child -> nodes that use the child as an input, and that have not been completed
self._pending_parents_of = defaultdict(lambda: set())

# key->depth. The shallowest level the key is found
self._depth_of = defaultdict(lambda: float('inf'))

# target keys that the dag should compute
self._targets = set()

self._working_graph = dict(dsk)
if low_memory_mode:
self._flatten_graph()

self.initialize_graph()

def left_to_compute(self):
return len(self._working_graph) - len(self._result_of)

def graph_keyp(self, s):
if DaskVineDag.keyp(s):
return s in self._working_graph
return False

def depth_of(self, key):
return self._depth_of[key]

def initialize_graph(self):
for key, sexpr in self._working_graph.items():
self.set_relations(key, sexpr)

def find_dependencies(self, sexpr, depth=0):
dependencies = set()
if self.graph_keyp(sexpr):
dependencies.add(sexpr)
self._depth_of[sexpr] = min(depth, self._depth_of[sexpr])
elif not DaskVineDag.symbolp(sexpr):
for sub in sexpr:
dependencies.update(self.find_dependencies(sub, depth + 1))
return dependencies
for task in self._working_graph.values():
self.set_relations(task)

def set_relations(self, key, sexpr):
sexpr = self._working_graph[key]
for task in self._working_graph.values():
if isinstance(task, dts.DataNode):
self._depth_of[task.key] = 0
self.set_result(task.key, task.value)

self._children_of[key] = self.find_dependencies(sexpr)
self._depth_of[key] = max([self._depth_of[c] for c in self._children_of[key]]) + 1 if self._children_of[key] else 0

self._missing_of[key] = set(self._children_of[key])

for c in self._children_of[key]:
self._parents_of[c].add(key)
self._pending_parents_of[c].add(key)
def set_relations(self, task):
self._dependencies_of[task.key] = task.dependencies
self._missing_of[task.key] = set(self._dependencies_of[task.key])
for c in self._dependencies_of[task.key]:
self._needed_by[c].add(task.key)
self._pending_needed_by[c].add(task.key)

def get_ready(self):
""" List of [(key, sexpr),...] ready for computation.
""" List of dts.Task ready for computation.
This call should be used only for
bootstrapping. Further calls should use DaskVineDag.set_result to discover
the new computations that become ready to be executed. """
rs = {}
for (key, cs) in self._missing_of.items():
if self.has_result(key) or cs:
continue
sexpr = self._working_graph[key]
if self.graph_keyp(sexpr):
rs.update(self.set_result(key, self.get_result(sexpr)))
elif self.symbolp(sexpr):
rs.update(self.set_result(key, sexpr))
node = self._working_graph[key]
if self.taskref(node):
rs.update(self.set_result(key, self.get_result(node.key)))
elif self.symbolp(node):
rs.update(self.set_result(key, node))
else:
rs[key] = (key, sexpr)
rs[key] = node

for r in rs:
if self._dependencies_of[r]:
self._depth_of[r] = min(self._depth_of[d] for d in self._dependencies_of[r]) + 1
else:
self._depth_of[r] = 0

return rs.values()

def set_result(self, key, value):
""" Sets new result and propagates in the DaskVineDag. Returns a list of [(key, sexpr),...]
""" Sets new result and propagates in the DaskVineDag. Returns a list of dts.Task
of computations that become ready to be executed """
rs = {}
self._result_of[key] = value
for p in self._parents_of[key]:
for p in self._pending_needed_by[key]:
self._missing_of[p].discard(key)

if self._missing_of[p]:
# the key p still has dependencies unmet...
continue

sexpr = self._working_graph[p]
if self.graph_keyp(sexpr):
node = self._working_graph[p]
if self.taskref(node):
rs.update(
self.set_result(p, self.get_result(sexpr))
self.set_result(p, self.get_result(node))
) # case e.g, "x": "y", and we just set the value of "y"
elif self.symbolp(sexpr):
rs.update(self.set_result(p, sexpr))
elif self.symbolp(node):
rs.update(self.set_result(p, node))
else:
rs[p] = (p, sexpr)
rs[p] = node

for c in self._children_of[key]:
self._pending_parents_of[c].discard(key)
for r in rs:
if self._dependencies_of[r]:
self._depth_of[r] = min(self._depth_of[d] for d in self._dependencies_of[r]) + 1
else:
self._depth_of[r] = 0

return rs.values()
for c in self._dependencies_of[key]:
self._pending_needed_by[c].discard(key)

def _flatten_graph(self):
""" Recursively decomposes a sexpr associated with key, so that its arguments, if any
are keys. """
for key in list(self._working_graph.keys()):
self.flatten_rec(key, self._working_graph[key], toplevel=True)
return rs.values()

def _add_second_targets(self, key):
v = self._working_graph[key]
if self.graph_keyp(v):
if self.taskref(v):
lst = [v]
elif DaskVineDag.listp(v):
elif DaskVineDag.containerp(v):
lst = v
else:
return
for c in lst:
if self.graph_keyp(c):
self._targets.add(c)
self._add_second_targets(c)

def flatten_rec(self, key, sexpr, toplevel=False):
if key in self._working_graph and not toplevel:
return
if DaskVineDag.symbolp(sexpr):
return

nargs = []
next_flat = []
cons = type(sexpr)

for arg in sexpr:
if DaskVineDag.symbolp(arg):
nargs.append(arg)
else:
next_key = uuid4()
nargs.append(next_key)
next_flat.append((next_key, arg))

self._working_graph[key] = cons(nargs)
for (n, a) in next_flat:
self.flatten_rec(n, a)
if self.taskref(c):
self._targets.add(c.key)
self._add_second_targets(c.key)

def has_result(self, key):
return key in self._result_of
Expand All @@ -219,17 +189,17 @@ def get_result(self, key):
except KeyError:
raise DaskVineNoResult(key)

def get_children(self, key):
return self._children_of[key]
def get_dependencies(self, key):
return self._dependencies_of[key]

def get_missing_children(self, key):
def get_missing_dependencies(self, key):
return self._missing_of[key]

def get_parents(self, key):
return self._parents_of[key]
def get_needed_by(self, key):
return self._needed_by[key]

def get_pending_parents(self, key):
return self._pending_parents_of[key]
def get_pending_needed_by(self, key):
return self._pending_needed_by[key]

def set_targets(self, keys):
""" Values of keys that need to be computed. """
Expand Down
Loading
Loading