Skip to content
This repository was archived by the owner on Jul 18, 2024. It is now read-only.

replace callback #47

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 18 additions & 22 deletions dask4dvc/dvc_repro.py
Original file line number Diff line number Diff line change
@@ -117,28 +117,25 @@ def reproduce_experiment(entry_dict: dict, infofile: str, successors: list) -> s
finally:
executor.cleanup(infofile)

return executor.info.name

with dask.distributed.Lock("dvc"):
repo = dvc.repo.Repo()
queue = repo.experiments.celery_queue
for msg in queue.celery.iter_queued():
if msg.headers.get("task") != tasks.run_exp.name:
continue
args, kwargs, _embed = msg.decode()
entry_dict = kwargs.get("entry_dict", args[0])
if entry_dict["name"] == executor.info.name:
queue.celery.reject(msg.delivery_tag)
if dask.distributed.Variable("cleanup").get():
# this one should only be called if the experiment
# should truly be removed
dvc.cli.main(["exp", "remove", executor.info.name])
if dask.distributed.Variable("repro").get():
# load experiments results into workspace
dvc.cli.main(["repro", "--single-item", executor.info.name])

def get_experiment_callback(name: dask.distributed.Future) -> None:
"""Get callback to run after an experiment is done."""
name = name.result()
with dask.distributed.Lock("dvc"):
repo = dvc.repo.Repo()
queue = repo.experiments.celery_queue
for msg in queue.celery.iter_queued():
if msg.headers.get("task") != tasks.run_exp.name:
continue
args, kwargs, _embed = msg.decode()
entry_dict = kwargs.get("entry_dict", args[0])
if entry_dict["name"] == name:
queue.celery.reject(msg.delivery_tag)
if dask.distributed.Variable("cleanup").get():
# this one should only be called if the experiment should truly be removed
dvc.cli.main(["exp", "remove", name])
if dask.distributed.Variable("repro").get():
# load experiments results into workspace
dvc.cli.main(["repro", "--single-item", name])
return executor.info.name


def submit_to_dask(
@@ -153,7 +150,6 @@ def submit_to_dask(
pure=False,
key=entry.name,
)
experiment.add_done_callback(get_experiment_callback)
return experiment