Skip to content

WIP: experiment generating kinds in parallel (multiprocess) #717

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 76 additions & 25 deletions src/taskgraph/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,18 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from collections import defaultdict
import copy
from itertools import chain
import logging
import os
from concurrent.futures import (
ALL_COMPLETED,
FIRST_COMPLETED,
ProcessPoolExecutor,
as_completed,
wait,
)
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Union

Expand Down Expand Up @@ -44,16 +53,20 @@ def _get_loader(self):
loader = "taskgraph.loader.default:loader"
return find_object(loader)

def load_tasks(self, parameters, loaded_tasks, write_artifacts):
def load_tasks(self, parameters, kind_dependencies_tasks, write_artifacts):
logger.debug(f"Loading tasks for kind {self.name}")

parameters = Parameters(**parameters)
loader = self._get_loader()
config = copy.deepcopy(self.config)

kind_dependencies = config.get("kind-dependencies", [])
kind_dependencies_tasks = {
task.label: task for task in loaded_tasks if task.kind in kind_dependencies
}

inputs = loader(self.name, self.path, config, parameters, loaded_tasks)
inputs = loader(
self.name,
self.path,
config,
parameters,
list(kind_dependencies_tasks.values()),
)

transforms = TransformSequence()
for xform_path in config["transforms"]:
Expand Down Expand Up @@ -87,6 +100,7 @@ def load_tasks(self, parameters, loaded_tasks, write_artifacts):
)
for task_dict in transforms(trans_config, inputs)
]
logger.info(f"Generated {len(tasks)} tasks for kind {self.name}")
return tasks

@classmethod
Expand Down Expand Up @@ -249,6 +263,59 @@ def _load_kinds(self, graph_config, target_kinds=None):
except KindNotFound:
continue

def _load_tasks(self, kinds, kind_graph, parameters):
    """Generate the tasks of every kind, loading independent kinds in parallel.

    A kind is submitted to a process pool once every kind it depends on
    has finished loading. Results are merged into the task mapping in the
    calling thread, so a kind is only unblocked after the tasks of its
    dependencies are actually available, and any failure propagates to
    the caller instead of being lost in a future callback.

    Args:
        kinds: mapping of kind name to kind object. Each kind must expose
            ``config`` (a mapping) and ``load_tasks(parameters,
            kind_dependencies_tasks, write_artifacts)``. The mapping is
            not modified.
        kind_graph: graph whose ``edges`` are tuples where ``edge[0]`` is
            a kind that depends on kind ``edge[1]``.
        parameters: the generation parameters; passed to each worker as a
            plain ``dict``.

    Returns:
        dict mapping task label to task, for the tasks of all kinds.

    Raises:
        Exception: if two tasks share the same label, or re-raised from a
            worker if loading a kind failed.
    """
    all_tasks = {}
    futures_to_kind = {}
    # Work on private copies: the scheduling state below is consumed as
    # kinds complete and the caller's objects must stay intact.
    pending_kinds = dict(kinds)
    edges = set(kind_graph.edges)

    with ProcessPoolExecutor() as executor:

        def submit_ready_kinds():
            """Submit every pending kind whose dependencies are all loaded."""
            kinds_with_deps = {edge[0] for edge in edges}
            in_flight = set(futures_to_kind.values())
            ready_kinds = set(pending_kinds) - kinds_with_deps - in_flight
            # Sorted only to make the submission order deterministic.
            for name in sorted(ready_kinds):
                kind = pending_kinds[name]
                kind_dependencies = set(kind.config.get("kind-dependencies", []))
                future = executor.submit(
                    kind.load_tasks,
                    dict(parameters),
                    {
                        label: task
                        for label, task in all_tasks.items()
                        if task.kind in kind_dependencies
                    },
                    self._write_artifacts,
                )
                futures_to_kind[future] = name

        submit_ready_kinds()
        while futures_to_kind:
            done, _ = wait(list(futures_to_kind), return_when=FIRST_COMPLETED)
            for future in done:
                name = futures_to_kind.pop(future)

                # Collect the result here rather than in a done callback:
                # exceptions raised in callbacks are logged and dropped,
                # and callbacks may run after waiters have been woken.
                try:
                    new_tasks = future.result()
                except Exception:
                    logger.exception(f"Error loading tasks for kind {name}:")
                    raise

                for task in new_tasks:
                    if task.label in all_tasks:
                        raise Exception("duplicate tasks with label " + task.label)
                    all_tasks[task.label] = task

                # Update state for the next batch: this kind is finished
                # and no longer blocks the kinds that depend on it.
                del pending_kinds[name]
                edges = {e for e in edges if e[1] != name}

            # Submit any newly unblocked kinds.
            submit_ready_kinds()

    return all_tasks

def _run(self):
logger.info("Loading graph configuration.")
graph_config = load_graph_config(self.root_dir)
Expand Down Expand Up @@ -303,24 +370,8 @@ def _run(self):
)

logger.info("Generating full task set")
all_tasks = {}
for kind_name in kind_graph.visit_postorder():
logger.debug(f"Loading tasks for kind {kind_name}")
kind = kinds[kind_name]
try:
new_tasks = kind.load_tasks(
parameters,
list(all_tasks.values()),
self._write_artifacts,
)
except Exception:
logger.exception(f"Error loading tasks for kind {kind_name}:")
raise
for task in new_tasks:
if task.label in all_tasks:
raise Exception("duplicate tasks with label " + task.label)
all_tasks[task.label] = task
logger.info(f"Generated {len(new_tasks)} tasks for kind {kind_name}")
all_tasks = self._load_tasks(kinds, kind_graph, parameters)

full_task_set = TaskGraph(all_tasks, Graph(frozenset(all_tasks), frozenset()))
yield self.verify("full_task_set", full_task_set, graph_config, parameters)

Expand Down
Loading