@@ -5,6 +5,12 @@
 import copy
 import logging
 import os
+import platform
+from concurrent.futures import (
+    FIRST_COMPLETED,
+    ProcessPoolExecutor,
+    wait,
+)
 from dataclasses import dataclass
 from typing import Callable, Dict, Optional, Union
 
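The new `concurrent.futures` imports support a completion-driven scheduling loop rather than a single `map()` call: futures are submitted in waves, and `wait(..., return_when=FIRST_COMPLETED)` wakes the caller as soon as any one of them finishes. A minimal, self-contained sketch of that pattern, with a toy `work` function standing in for `Kind.load_tasks`:

```python
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait


def work(n):
    # Stand-in for a real workload such as Kind.load_tasks.
    return n * n


if __name__ == "__main__":
    with ProcessPoolExecutor() as executor:
        futures = {executor.submit(work, n) for n in range(5)}
        while futures:
            # Wake up as soon as any future finishes; keep waiting on the rest.
            done, futures = wait(futures, return_when=FIRST_COMPLETED)
            for future in done:
                print(future.result())
```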
@@ -46,16 +52,20 @@ def _get_loader(self) -> Callable:
         assert callable(loader)
         return loader
 
-    def load_tasks(self, parameters, loaded_tasks, write_artifacts):
+    def load_tasks(self, parameters, kind_dependencies_tasks, write_artifacts):
+        logger.debug(f"Loading tasks for kind {self.name}")
+
+        parameters = Parameters(**parameters)
         loader = self._get_loader()
         config = copy.deepcopy(self.config)
 
-        kind_dependencies = config.get("kind-dependencies", [])
-        kind_dependencies_tasks = {
-            task.label: task for task in loaded_tasks if task.kind in kind_dependencies
-        }
-
-        inputs = loader(self.name, self.path, config, parameters, loaded_tasks)
+        inputs = loader(
+            self.name,
+            self.path,
+            config,
+            parameters,
+            list(kind_dependencies_tasks.values()),
+        )
 
         transforms = TransformSequence()
         for xform_path in config["transforms"]:
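A note on `parameters = Parameters(**parameters)`: the parallel caller below passes a plain `dict` (see `dict(parameters)` in `_load_tasks_parallel`), presumably so the argument crosses the process boundary as an ordinary mapping, and `load_tasks` rebuilds the `Parameters` object on the worker side. Roughly this round-trip, sketched with a stand-in class:

```python
import pickle


class Parameters(dict):
    # Stand-in for taskgraph's Parameters type, which wraps a mapping.
    pass


params = Parameters(level="1", project="example")
payload = pickle.dumps(dict(params))  # what crosses the process boundary
rebuilt = Parameters(**pickle.loads(payload))  # what load_tasks reconstructs
assert rebuilt == params
```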
@@ -89,6 +99,7 @@ def load_tasks(self, parameters, loaded_tasks, write_artifacts):
             )
             for task_dict in transforms(trans_config, inputs)
         ]
+        logger.info(f"Generated {len(tasks)} tasks for kind {self.name}")
         return tasks
 
     @classmethod
@@ -253,6 +264,101 @@ def _load_kinds(self, graph_config, target_kinds=None):
             except KindNotFound:
                 continue
 
+    def _load_tasks_serial(self, kinds, kind_graph, parameters):
+        all_tasks = {}
+        for kind_name in kind_graph.visit_postorder():
+            logger.debug(f"Loading tasks for kind {kind_name}")
+
+            kind = kinds.get(kind_name)
+            if not kind:
+                message = f'Could not find the kind "{kind_name}"\nAvailable kinds:\n'
+                for k in sorted(kinds):
+                    message += f' - "{k}"\n'
+                raise Exception(message)
+
+            try:
+                new_tasks = kind.load_tasks(
+                    parameters,
+                    {
+                        k: t
+                        for k, t in all_tasks.items()
+                        if t.kind in kind.config.get("kind-dependencies", [])
+                    },
+                    self._write_artifacts,
+                )
+            except Exception:
+                logger.exception(f"Error loading tasks for kind {kind_name}:")
+                raise
+            for task in new_tasks:
+                if task.label in all_tasks:
+                    raise Exception("duplicate tasks with label " + task.label)
+                all_tasks[task.label] = task
+
+        return all_tasks
+
+    def _load_tasks_parallel(self, kinds, kind_graph, parameters):
+        all_tasks = {}
+        futures_to_kind = {}
+        futures = set()
+        edges = set(kind_graph.edges)
+
+        with ProcessPoolExecutor() as executor:
+
+            def submit_ready_kinds():
+                """Submit load_tasks for each kind whose dependencies are all loaded."""
+                nonlocal kinds, edges, futures
+                loaded_tasks = all_tasks.copy()
+                kinds_with_deps = {edge[0] for edge in edges}
+                ready_kinds = (
+                    set(kinds) - kinds_with_deps - set(futures_to_kind.values())
+                )
+                for name in ready_kinds:
+                    kind = kinds.get(name)
+                    if not kind:
+                        message = (
+                            f'Could not find the kind "{name}"\nAvailable kinds:\n'
+                        )
+                        for k in sorted(kinds):
+                            message += f' - "{k}"\n'
+                        raise Exception(message)
+
+                    future = executor.submit(
+                        kind.load_tasks,
+                        dict(parameters),
+                        {
+                            k: t
+                            for k, t in loaded_tasks.items()
+                            if t.kind in kind.config.get("kind-dependencies", [])
+                        },
+                        self._write_artifacts,
+                    )
+                    futures.add(future)
+                    futures_to_kind[future] = name
+
+            submit_ready_kinds()
+            while futures:
+                done, _ = wait(futures, return_when=FIRST_COMPLETED)
+                for future in done:
+                    if exc := future.exception():
+                        executor.shutdown(wait=False, cancel_futures=True)
+                        raise exc
+                    kind = futures_to_kind.pop(future)
+                    futures.remove(future)
+
+                    for task in future.result():
+                        if task.label in all_tasks:
+                            raise Exception("duplicate tasks with label " + task.label)
+                        all_tasks[task.label] = task
+
+                    # Update state for the next batch of futures.
+                    del kinds[kind]
+                    edges = {e for e in edges if e[1] != kind}
+
+                # Submit any newly unblocked kinds.
+                submit_ready_kinds()
+
+        return all_tasks
+
     def _run(self):
         logger.info("Loading graph configuration.")
         graph_config = load_graph_config(self.root_dir)
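The bookkeeping in `_load_tasks_parallel` is easier to see without the executor: a kind is ready once no outgoing edge leaves it and it has not been submitted yet, and finishing a kind prunes every edge that points at it, unblocking its dependents. A toy, single-threaded rendition of that loop, with hypothetical kind names:

```python
kinds = {"fetch", "build", "test", "docs"}
# (dependent, dependency) pairs: build needs fetch, test needs build.
edges = {("build", "fetch"), ("test", "build")}

finished = []
while kinds:
    blocked = {dependent for dependent, _ in edges}
    for name in sorted(kinds - blocked):
        finished.append(name)
        kinds.remove(name)
        # Prune edges pointing at the finished kind; its dependents may
        # become ready on the next pass.
        edges = {e for e in edges if e[1] != name}

print(finished)  # ['docs', 'fetch', 'build', 'test']
```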
@@ -307,31 +413,18 @@ def _run(self):
         )
 
         logger.info("Generating full task set")
-        all_tasks = {}
-        for kind_name in kind_graph.visit_postorder():
-            logger.debug(f"Loading tasks for kind {kind_name}")
-
-            kind = kinds.get(kind_name)
-            if not kind:
-                message = f'Could not find the kind "{kind_name}"\nAvailable kinds:\n'
-                for k in sorted(kinds):
-                    message += f' - "{k}"\n'
-                raise Exception(message)
+        # Parallel generation currently relies on multiprocessing with the
+        # fork start method. This causes problems on Windows and macOS, which
+        # spawn new processes instead: spawning reinitializes global state
+        # that was modified earlier in graph generation, and that setup is
+        # never redone in the worker processes. Ideally this would be fixed,
+        # or we would take another approach to parallel kind generation. In
+        # the meantime, it is not supported outside of Linux.
+        if platform.system() != "Linux":
+            all_tasks = self._load_tasks_serial(kinds, kind_graph, parameters)
+        else:
+            all_tasks = self._load_tasks_parallel(kinds, kind_graph, parameters)
 
-            try:
-                new_tasks = kind.load_tasks(
-                    parameters,
-                    list(all_tasks.values()),
-                    self._write_artifacts,
-                )
-            except Exception:
-                logger.exception(f"Error loading tasks for kind {kind_name}:")
-                raise
-            for task in new_tasks:
-                if task.label in all_tasks:
-                    raise Exception("duplicate tasks with label " + task.label)
-                all_tasks[task.label] = task
-            logger.info(f"Generated {len(new_tasks)} tasks for kind {kind_name}")
         full_task_set = TaskGraph(all_tasks, Graph(frozenset(all_tasks), frozenset()))
         yield self.verify("full_task_set", full_task_set, graph_config, parameters)
 
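The `platform.system()` guard comes down to multiprocessing start methods: with `fork` (the historical default on Linux), workers inherit the parent's already-initialized globals; with `spawn` (the only option on Windows, and the default on macOS since Python 3.8), workers start from a fresh interpreter and re-import modules instead. A quick way to see which method a platform uses:

```python
import multiprocessing

# Typically "fork" on Linux and "spawn" on Windows and modern macOS.
print(multiprocessing.get_start_method())
```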