Skip to content

Commit b807cf3

Browse files
committed
tasks: refactor run_all_tasks to TaskPlan
The run_all_tasks() function was getting large and unwieldy. Upcoming task features (support for including files) will require even more complicated logic. To prepare, refactor this function into a class which contains the shared state, and multiple methods that implement parts of the process. One nice benefit here is that we can now do the initial checking (loading dependencies, checking for conflicts) prior to launching an instance. This further helps avoid the case where bad task configuration could cause the instance setup to fail after an otherwise successful instance launch. Another nice benefit is that we can print the plan for execution as part of the --dry-run, which means we can also add a --dry-run option for "yo task run". Aside from these two benefits, no other functional change is intended. Signed-off-by: Stephen Brennan <[email protected]>
1 parent ad9222f commit b807cf3

File tree

2 files changed

+123
-74
lines changed

2 files changed

+123
-74
lines changed

yo/main.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@
115115
from yo.ssh import SSH_OPTIONS
116116
from yo.ssh import wait_for_ssh_access
117117
from yo.tasks import list_tasks
118-
from yo.tasks import run_all_tasks
119118
from yo.tasks import task_get_status
120119
from yo.tasks import task_join
121120
from yo.tasks import task_status_to_table
121+
from yo.tasks import TaskPlan
122122
from yo.tasks import YoTask
123123
from yo.util import current_yo_version
124124
from yo.util import fmt_allow_deny
@@ -1232,11 +1232,21 @@ def add_args(self, parser: argparse.ArgumentParser) -> None:
12321232
action="store_true",
12331233
help="should we wait until the task is finished?",
12341234
)
1235+
parser.add_argument(
1236+
"--dry-run",
1237+
action="store_true",
1238+
help="just print what would happen",
1239+
)
12351240

12361241
def run_for_instance(self, inst: YoInstance) -> None:
1237-
run_all_tasks(self.c, inst, self.args.task)
1242+
plan = TaskPlan(self.args.task)
1243+
plan.prepare()
1244+
if self.args.dry_run:
1245+
plan.dry_run_print()
1246+
return
1247+
plan.run(self.c, inst)
12381248
if self.args.wait:
1239-
task_join(self.c, inst, wait_tasks=self.args.task)
1249+
plan.join(self.c, inst)
12401250
send_notification(
12411251
self.c,
12421252
f"Task {self.args.task} complete on instance {inst.name}",
@@ -1842,17 +1852,18 @@ def run(self) -> None:
18421852
# Load tasks before we launch. That way an invalid configuration is
18431853
# detected ASAP, and the user could correct the config before we've
18441854
# actually launched.
1845-
tasks = set(profile.tasks + self.args.tasks)
1855+
task_plan = TaskPlan(profile.tasks + self.args.tasks)
1856+
task_plan.prepare()
18461857

18471858
self.c.con.log(f"Launching instance [blue]{name}[/blue]")
18481859
if self.args.dry_run:
18491860
self.c.con.log("DRY RUN. Args below:")
18501861
self.c.con.log(create_args)
1851-
self.c.con.log(f"Will launch tasks: {tasks}")
1862+
task_plan.dry_run_print()
18521863
return
18531864
inst = self.c.launch_instance(create_args)
18541865

1855-
self.standardize_wait(bool(tasks))
1866+
self.standardize_wait(task_plan.have_tasks())
18561867

18571868
if not self.args.wait:
18581869
return
@@ -1868,9 +1879,9 @@ def run(self) -> None:
18681879
self.c.con.log("Maybe you're not connected to VPN?")
18691880
return
18701881

1871-
run_all_tasks(self.c, inst, tasks)
1872-
if tasks:
1873-
task_join(self.c, inst)
1882+
task_plan.run(self.c, inst)
1883+
if task_plan.have_tasks():
1884+
task_plan.join(self.c, inst)
18741885

18751886
send_notification(self.c, f"Instance {inst.name} is ready!")
18761887

yo/tasks.py

+103-65
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
os.path.expanduser("~/.oci/yo-tasks"),
6161
# This should be installed with the package
6262
os.path.join(os.path.abspath(os.path.dirname(__file__)), "data/yo-tasks"),
63+
# Uncomment for testing, if running from a git checkout:
64+
# os.path.join(os.path.abspath(os.path.dirname(__file__)), "../test-tasks"),
6365
]
6466

6567

@@ -200,72 +202,108 @@ def _task_run(ctx: "YoCtx", inst: YoInstance, task: YoTask) -> None:
200202
ssh_into(ip, user, ctx, extra_args=["-q"], cmds=[commands], quiet=True)
201203

202204

203-
def run_all_tasks(
204-
ctx: "YoCtx", inst: YoInstance, tasks: t.Iterable[t.Union[YoTask, str]]
205-
) -> None:
206-
# The caller may specify tasks as either strings or YoTask instances, for
207-
# convenience. Let's get everything into a "name_to_task" dict.
208-
name_to_task: t.Dict[str, YoTask] = {}
209-
for task in tasks:
210-
if isinstance(task, str):
211-
name_to_task[task] = YoTask.load(task)
212-
else:
213-
name_to_task[task.name] = task
214-
215-
# Tasks may have dependencies. Let's go through every task, and their
216-
# dependencies, and load them all. At this point, we're not yet checking
217-
# whether there are any circular dependencies: just loading them.
218-
tasks_to_load = list(name_to_task.values())
219-
for task in tasks_to_load:
220-
for name in task.dependencies:
221-
if name not in name_to_task:
222-
name_to_task[name] = YoTask.load(name)
223-
tasks_to_load.append(name_to_task[name])
224-
225-
# Now we have loaded the complete set of tasks that should run. Some tasks
226-
# may appoint themselves as "prerequisites" for another. We need to insert
227-
# this dependency relationship so that the script is updated, and so that
228-
# the circular dependency detection knows about it. We can also use this
229-
# opportunity to detect conflicts.
230-
for task in name_to_task.values():
231-
for name in task.prereq_for:
232-
if name in name_to_task:
233-
name_to_task[name].insert_prereq(task.name)
234-
for name in task.conflicts:
235-
if name in name_to_task:
236-
raise YoExc(f"Task {task.name} conflicts with {name}")
237-
238-
# Now all tasks are loaded, and prerequisites have been marked. Use a
239-
# topological sort to verify that no circular dependencies are present. Here
240-
# we use a recursive traversal because honestly, if you specify enough tasks
241-
# to trigger a recursion error, then I would like to receive that bug
242-
# report!
243-
name_to_visit: t.Dict[str, int] = collections.defaultdict(int)
244-
ordered_tasks: t.List[YoTask] = []
245-
246-
def visit(task: YoTask) -> None:
247-
if name_to_visit[task.name] == 2:
248-
# already completed, skip
205+
class TaskPlan:
206+
"""
207+
A group of tasks to be run all together on an instance
208+
209+
This plan encompasses the necessary files to copy over, as well as the order
210+
of the tasks to run.
211+
"""
212+
213+
name_to_task: t.Dict[str, YoTask]
214+
ordered_tasks: t.List[YoTask]
215+
216+
def _prepare_prereqs_check_conflicts(self) -> None:
    """
    Apply "prerequisite for" relationships and reject conflicting tasks.

    A task may declare itself a prerequisite of another task. When that
    other task is part of this plan, the relationship is recorded on it via
    ``insert_prereq()`` so that the later ordering step sees it. Names that
    are not part of the plan are ignored.

    Raises YoExc if a task lists a conflict with any task in the plan.
    """
    planned = self.name_to_task
    for task in planned.values():
        # Record this task as a prerequisite of each planned task it names.
        for target_name in task.prereq_for:
            target = planned.get(target_name)
            if target is not None:
                target.insert_prereq(task.name)

        # The first conflicting name that is also planned aborts the plan.
        clash = next((name for name in task.conflicts if name in planned), None)
        if clash is not None:
            raise YoExc(f"Task {task.name} conflicts with {clash}")
229+
230+
def _create_execution_order(self) -> None:
    """
    Topologically sort the planned tasks into ``self.ordered_tasks``.

    Each task is appended only after all of its dependencies, using a
    recursive depth-first traversal. Recursion is deliberate: a plan deep
    enough to hit the recursion limit is worth a bug report.

    Raises YoExc if the dependencies contain a cycle.
    """
    in_progress, finished = 1, 2
    # Task name -> traversal state; a missing name means "not yet visited".
    state: t.Dict[str, int] = {}
    # Empty the list in place so a second call does not duplicate entries.
    self.ordered_tasks.clear()

    def visit(task: YoTask) -> None:
        status = state.get(task.name, 0)
        if status == finished:
            return
        if status == in_progress:
            # We came back to a task whose dependencies are still being
            # walked, so the graph is not a DAG.
            raise YoExc("Tasks express a circular dependency")

        state[task.name] = in_progress
        for dep_name in task.dependencies:
            visit(self.name_to_task[dep_name])
        state[task.name] = finished
        self.ordered_tasks.append(task)

    for root in self.name_to_task.values():
        visit(root)
255+
256+
def __init__(self, tasks: t.Iterable[t.Union[YoTask, str]]) -> None:
    """
    Build a plan from the requested tasks plus all of their dependencies.

    ``tasks`` may mix task names and ``YoTask`` objects; names are resolved
    with ``YoTask.load()``. Dependencies are loaded transitively, but no
    cycle or conflict checking happens here, and ``ordered_tasks`` starts
    out empty.
    """
    self.name_to_task = {}
    for entry in tasks:
        if isinstance(entry, str):
            name, task = entry, YoTask.load(entry)
        else:
            name, task = entry.name, entry
        self.name_to_task[name] = task

    # Breadth-first walk over dependencies: every newly loaded task is
    # queued so that its own dependencies get loaded as well.
    pending = collections.deque(self.name_to_task.values())
    while pending:
        current = pending.popleft()
        for dep_name in current.dependencies:
            if dep_name not in self.name_to_task:
                self.name_to_task[dep_name] = YoTask.load(dep_name)
                pending.append(self.name_to_task[dep_name])

    self.ordered_tasks = []
277+
278+
def prepare(self) -> None:
    """
    Validate the plan and compute its execution order.

    First applies prerequisite relationships and checks for conflicts, then
    builds ``self.ordered_tasks``. Takes no context or instance arguments.
    """
    steps = (
        self._prepare_prereqs_check_conflicts,
        self._create_execution_order,
    )
    for step in steps:
        step()
281+
282+
def have_tasks(self) -> bool:
283+
return bool(self.ordered_tasks)
284+
285+
def dry_run_print(self) -> None:
286+
if not self.have_tasks():
287+
print("No tasks to run.")
249288
return
250-
if name_to_visit[task.name] == 1:
251-
# currently visiting, not a DAG
252-
raise YoExc("Tasks express a circular dependency")
253-
254-
name_to_visit[task.name] = 1
255-
for dep_name in task.dependencies:
256-
visit(name_to_task[dep_name])
257-
name_to_visit[task.name] = 2
258-
ordered_tasks.append(task)
259-
260-
for task in name_to_task.values():
261-
visit(task)
262-
263-
# Now ordered_tasks contains the order in which we should launch them. This
264-
# is just a nice-to-have: even if we launched them out of order, the
265-
# DEPENDS_ON function would enforce the order of execution. Regardless,
266-
# let's start the tasks.
267-
for task in ordered_tasks:
268-
_task_run(ctx, inst, task)
289+
print("Would start tasks in the following order:")
290+
for task in self.ordered_tasks:
291+
deps = ""
292+
if task.dependencies:
293+
deps = f" (depends on: {', '.join(task.dependencies)})"
294+
print(f" 1. {task.name}{deps}")
295+
296+
def run(self, ctx: "YoCtx", inst: YoInstance) -> None:
    """
    Start every task of the plan on ``inst``, following ``self.ordered_tasks``.

    Launching in this order is a nicety rather than a requirement: even if
    tasks were started out of order, the DEPENDS_ON function would still
    enforce the order of execution.
    """
    for planned_task in self.ordered_tasks:
        _task_run(ctx, inst, planned_task)
303+
304+
def join(self, ctx: "YoCtx", inst: YoInstance) -> None:
    """Call ``task_join`` on ``inst`` with the names of all ordered tasks."""
    # The loop variable is spelled out so it cannot be mistaken for the
    # module-level ``t`` alias used in this file's type annotations.
    task_join(ctx, inst, [task.name for task in self.ordered_tasks])
269307

270308

271309
def task_get_status(

0 commit comments

Comments
 (0)