vine: file pruning by depth #4057

Merged
merged 1 commit on Feb 14, 2025
69 changes: 58 additions & 11 deletions taskvine/src/bindings/python3/ndcctools/taskvine/compat/dask_dag.py
@@ -58,7 +58,7 @@ def hashable(s):
        except TypeError:
            return False

-    def __init__(self, dsk, low_memory_mode=False):
+    def __init__(self, dsk, low_memory_mode=False, prune_depth=0):
        self._dsk = dsk

        # child -> parents, i.e., which parents need the result of the child
@@ -73,9 +73,6 @@ def __init__(self, dsk, low_memory_mode=False, prune_depth=0):
        # key->value of its computation
        self._result_of = {}

-        # child -> nodes that use the child as an input, and that have not been completed
-        self._pending_parents_of = defaultdict(lambda: set())
-
        # key->depth. The shallowest level at which the key is found
        self._depth_of = defaultdict(lambda: float('inf'))

@@ -86,6 +83,10 @@ def __init__(self, dsk, low_memory_mode=False, prune_depth=0):
        if low_memory_mode:
            self._flatten_graph()

+        self.prune_depth = prune_depth
+        self.pending_consumers = defaultdict(int)
+        self.pending_producers = defaultdict(lambda: set())
+
        self.initialize_graph()

    def left_to_compute(self):
@@ -103,6 +104,11 @@ def initialize_graph(self):
        for key, sexpr in self._working_graph.items():
            self.set_relations(key, sexpr)

+        # Then initialize pending consumers and producers if pruning is enabled
+        if self.prune_depth > 0:
+            self._initialize_pending_consumers()
+            self._initialize_pending_producers()
+
    def find_dependencies(self, sexpr, depth=0):
        dependencies = set()
        if self.graph_keyp(sexpr):
@@ -123,7 +129,53 @@ def set_relations(self, key, sexpr):

        for c in self._children_of[key]:
            self._parents_of[c].add(key)
-            self._pending_parents_of[c].add(key)
+
+    def _initialize_pending_consumers(self):
+        """Initialize pending consumer counts up to prune_depth levels."""
+        for key in self._working_graph:
+            if key not in self.pending_consumers:
+                count = 0
+                # BFS to count consumers up to prune_depth
+                visited = set()
+                queue = [(c, 1) for c in self._parents_of[key]]  # (consumer, depth)
+
+                while queue:
+                    consumer, depth = queue.pop(0)
+                    if depth <= self.prune_depth and consumer not in visited:
+                        visited.add(consumer)
+                        count += 1
+
+                        # Add next-level consumers if we haven't reached the maximum depth
+                        if depth < self.prune_depth:
+                            next_consumers = [(c, depth + 1) for c in self._parents_of[consumer]]
+                            queue.extend(next_consumers)
+
+                self.pending_consumers[key] = count
+
+    def _initialize_pending_producers(self):
+        """Initialize pending producers up to prune_depth levels."""
+        if self.prune_depth <= 0:
+            return
+
+        for key in self._working_graph:
+            # Use a set to store unique producers
+            producers = set()
+            visited = set()
+            queue = [(p, 1) for p in self._children_of[key]]  # (producer, depth)
+
+            while queue:
+                producer, depth = queue.pop(0)
+                if depth <= self.prune_depth and producer not in visited:
+                    visited.add(producer)
+                    producers.add(producer)
+
+                    # Add next-level producers if we haven't reached the maximum depth
+                    if depth < self.prune_depth:
+                        next_producers = [(p, depth + 1) for p in self._children_of[producer]]
+                        queue.extend(next_producers)
+
+            # Store all producers for this key in pending_producers
+            self.pending_producers[key] = producers
+
    def get_ready(self):
        """ List of [(key, sexpr),...] ready for computation.
@@ -148,6 +200,7 @@ def set_result(self, key, value):
        of computations that become ready to be executed """
        rs = {}
        self._result_of[key] = value
+
        for p in self._parents_of[key]:
            self._missing_of[p].discard(key)

@@ -164,9 +217,6 @@ def set_result(self, key, value):
            else:
                rs[p] = (p, sexpr)

-        for c in self._children_of[key]:
-            self._pending_parents_of[c].discard(key)
-
        return rs.values()

    def _flatten_graph(self):
@@ -228,9 +278,6 @@ def get_missing_children(self, key):
    def get_parents(self, key):
        return self._parents_of[key]

-    def get_pending_parents(self, key):
-        return self._pending_parents_of[key]
-
    def set_targets(self, keys):
        """ Values of keys that need to be computed. """
        self._targets.update(keys)
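For reference, the depth-limited BFS behind `_initialize_pending_consumers` can be exercised on its own. The following is a minimal standalone sketch, not part of the patch: `count_consumers`, `parents_of`, and the toy graph are illustrative stand-ins for the DAG's `_parents_of` map and its `pending_consumers` counts.

```python
from collections import defaultdict, deque

def count_consumers(parents_of, prune_depth):
    """For each key, count the distinct consumers reachable within prune_depth levels."""
    pending_consumers = defaultdict(int)
    for key in parents_of:
        visited = set()
        queue = deque((c, 1) for c in parents_of[key])
        while queue:
            consumer, depth = queue.popleft()
            if depth <= prune_depth and consumer not in visited:
                visited.add(consumer)
                if depth < prune_depth:
                    queue.extend((c, depth + 1) for c in parents_of.get(consumer, ()))
        pending_consumers[key] = len(visited)
    return pending_consumers

# Toy graph a -> b -> c: b consumes a, c consumes b.
parents = {"a": {"b"}, "b": {"c"}, "c": set()}
print(dict(count_consumers(parents, prune_depth=1)))  # {'a': 1, 'b': 1, 'c': 0}
print(dict(count_consumers(parents, prune_depth=2)))  # {'a': 2, 'b': 1, 'c': 0}
```

With `prune_depth=1` only direct consumers are counted; raising it to 2 also counts `c` as a pending consumer of `a`, so `a`'s file would survive until `c` finishes.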
taskvine/src/bindings/python3/ndcctools/taskvine/dask_executor.py
@@ -107,7 +107,7 @@ class DaskVine(Manager):
    # fn(*args) at some point during its execution to produce the dask task result.
    # Should return a tuple of (wrapper result, dask call result). Use for debugging.
    # @param wrapper_proc Function to process results from wrapper on completion. (default is print)
-    # @param prune_files If True, remove files from the cluster after they are no longer needed.
+    # @param prune_depth Controls file pruning: 0 (default) no pruning; 1 prune a file once its direct consumers complete; 2+ wait for consumers up to the given depth.
    def get(self, dsk, keys, *,
            environment=None,
            extra_files=None,
@@ -132,7 +132,7 @@ def get(self, dsk, keys, *,
            progress_label="[green]tasks",
            wrapper=None,
            wrapper_proc=print,
-            prune_files=False,
+            prune_depth=0,
            hoisting_modules=None,  # Deprecated, use lib_modules
            import_modules=None,  # Deprecated, use lib_modules
            lazy_transfers=True,  # Deprecated, use worker_transfers
@@ -174,7 +174,7 @@ def get(self, dsk, keys, *,
            self.progress_label = progress_label
            self.wrapper = wrapper
            self.wrapper_proc = wrapper_proc
-            self.prune_files = prune_files
+            self.prune_depth = prune_depth
            self.category_info = defaultdict(lambda: {"num_tasks": 0, "total_execution_time": 0})
            self.max_priority = float('inf')
            self.min_priority = float('-inf')
@@ -212,7 +212,7 @@ def _dask_execute(self, dsk, keys):
        indices = {k: inds for (k, inds) in find_dask_keys(keys)}
        keys_flatten = indices.keys()

-        dag = DaskVineDag(dsk, low_memory_mode=self.low_memory_mode)
+        dag = DaskVineDag(dsk, low_memory_mode=self.low_memory_mode, prune_depth=self.prune_depth)
        tag = f"dag-{id(dag)}"

        # create Library if using 'function-calls' task mode.
@@ -294,8 +294,12 @@ def _dask_execute(self, dsk, keys):
                    if t.key in dsk:
                        bar_update(advance=1)

-                    if self.prune_files:
-                        self._prune_file(dag, t.key)
+                    if self.prune_depth > 0:
+                        for p in dag.pending_producers[t.key]:
+                            dag.pending_consumers[p] -= 1
+                            if dag.pending_consumers[p] == 0:
+                                p_result = dag.get_result(p)
+                                self.prune_file(p_result._file)
                else:
                    retries_left = t.decrement_retry()
                    print(f"task id {t.id} key {t.key} failed: {t.result}. {retries_left} attempts left.\n{t.std_output}")
@@ -446,14 +450,6 @@ def _fill_key_result(self, dag, key):
            return raw.load()
        else:
            return raw
-
-    def _prune_file(self, dag, key):
-        children = dag.get_children(key)
-        for c in children:
-            if len(dag.get_pending_parents(c)) == 0:
-                c_result = dag.get_result(c)
-                self.prune_file(c_result._file)
-
##
# @class ndcctools.taskvine.dask_executor.DaskVineFile
#
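End to end, the new knob is passed through `DaskVine.get`. The following is a hedged usage sketch, assuming dask and ndcctools are installed, a `vine_worker` can reach port 9123, and that extra keyword arguments to `compute()` are forwarded to the scheduler's `get` (the pattern used in the TaskVine Dask examples); the graph and port number are illustrative.

```python
import dask
from ndcctools.taskvine import DaskVine

@dask.delayed
def inc(x):
    return x + 1

@dask.delayed
def add(x, y):
    return x + y

z = add(inc(1), inc(2))

m = DaskVine(port=9123)
with dask.config.set(scheduler=m.get):
    # prune_depth=1 removes an intermediate file from the cluster as soon as
    # all of its direct consumers have finished; 0 (the default) keeps every file.
    result = z.compute(prune_depth=1)
print(result)  # 5
```

Callers that previously passed `prune_files=True` get roughly the same behavior from `prune_depth=1`: both wait for all direct consumers of a file to finish before pruning it.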