Automatically trigger loop when possible #237

Status: Open. Wants to merge 3 commits into base: main.
22 changes: 18 additions & 4 deletions adaptive_scheduler/_server_support/database_manager.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import json
import pickle
from dataclasses import asdict, dataclass, field
@@ -194,6 +195,7 @@
self._pickling_time: float | None = None
self._total_learner_size: int | None = None
self._db: SimpleDatabase | None = None
self._trigger_event = asyncio.Event()

def _setup(self) -> None:
if self.db_fname.exists() and not self.overwrite_db:
@@ -214,7 +216,9 @@
queue = self.scheduler.queue(me_only=True)
job_names_in_queue = [x["job_name"] for x in queue.values()]
failed = self._db.get_all(
lambda e: e.job_name is not None and e.job_name not in job_names_in_queue, # type: ignore[operator]
lambda e: not e.is_done
and e.job_name is not None
and e.job_name not in job_names_in_queue, # type: ignore[operator]
)
self.failed.extend([asdict(entry) for _, entry in failed])
indices = [index for index, _ in failed]
@@ -272,6 +276,11 @@
for f in output_fnames
]

def _done_but_still_running(self, running: dict[str, Any]) -> list[tuple[int, _DBEntry]]:
if self._db is None:
return []

return self._db.get_all(lambda e: e.is_done and e.job_id in running)

def _choose_fname(self) -> tuple[int, str | list[str] | None]:
assert self._db is not None
entry = self._db.get(
@@ -338,7 +347,7 @@

def _stop_request(self, fname: str | list[str] | Path | list[Path]) -> None:
fname_str = _ensure_str(fname)
reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
reset = {"is_done": True, "is_pending": False}
assert self._db is not None
entry_indices = [index for index, _ in self._db.get_all(lambda e: e.fname == fname_str)]
self._db.update(reset, entry_indices)
@@ -347,7 +356,7 @@
# Same as `_stop_request` but optimized for processing many `fnames` at once
assert self._db is not None
fnames_str = {str(fname) for fname in _ensure_str(fnames)}
reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
reset = {"is_done": True, "is_pending": False}
entry_indices = [
index for index, _ in self._db.get_all(lambda e: str(e.fname) in fnames_str)
]
@@ -381,7 +390,8 @@
if request_type == "stop":
fname = request_arg[0] # workers send us the fname they were given
log.debug("got a stop request", fname=fname)
self._stop_request(fname) # reset the job_id to None
self._stop_request(fname) # set is_done
self.trigger_scheduling_event()
return None
except Exception as e: # noqa: BLE001
return e
@@ -426,3 +436,7 @@
break
finally:
socket.close()

def trigger_scheduling_event(self) -> None:
"""External method to trigger the _manage loop to continue."""
self._trigger_event.set()
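
Note on the mechanism: the asyncio.Event now owned by DatabaseManager lets any caller (for example the "stop" request handler above, or the trigger() method in job_manager.py) wake the scheduling loop immediately instead of waiting out the full polling interval. Below is a minimal sketch of that wake-up pattern, assuming only an event and a timeout; the real helper, sleep_unless_task_is_done, also watches the database task, so treat this as an illustration rather than the actual implementation.

import asyncio


async def sleep_unless_triggered(event: asyncio.Event, interval: float) -> bool:
    """Sleep up to `interval` seconds, but return early if `event` is set."""
    try:
        await asyncio.wait_for(event.wait(), timeout=interval)
        return True  # woken early by a trigger
    except asyncio.TimeoutError:
        return False  # normal periodic wake-up
    finally:
        event.clear()  # re-arm for the next round


async def demo() -> None:
    trigger = asyncio.Event()
    # Pretend a stop request arrives after 0.1 s and fires the trigger...
    asyncio.get_running_loop().call_later(0.1, trigger.set)
    # ...so this returns after ~0.1 s instead of the full 60 s interval.
    assert await sleep_unless_triggered(trigger, interval=60)


asyncio.run(demo())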
15 changes: 8 additions & 7 deletions adaptive_scheduler/_server_support/job_manager.py
@@ -174,7 +174,6 @@
# Other attributes
self.n_started = 0
self._request_times: dict[str, str] = {}
self._trigger_event = asyncio.Event()

# Command line launcher options
self.save_dataframe = save_dataframe
@@ -217,12 +216,14 @@
running = self.scheduler.queue(me_only=True)
self.database_manager.update(running) # in case some jobs died
queued = self._queued(running) # running `job_name`s
not_queued = set(self.job_names) - queued
available_job_names = set(self.job_names) - queued
n_done = self.database_manager.n_done()
if n_done == len(self.job_names):
return None # we are finished!
n_to_schedule = max(0, len(not_queued) - n_done)
return queued, set(list(not_queued)[:n_to_schedule])
n_done_but_running = len(self.database_manager._done_but_still_running(running))
n_done_completely = n_done - n_done_but_running
n_to_schedule = max(0, len(available_job_names) - n_done_completely)
return queued, set(list(available_job_names)[:n_to_schedule])

async def _start_new_jobs(
self,
@@ -263,7 +264,7 @@
if await sleep_unless_task_is_done(
self.database_manager.task, # type: ignore[arg-type]
self.interval,
self._trigger_event,
self.database_manager._trigger_event,
): # if true, we are done
return
except asyncio.CancelledError: # noqa: PERF203
@@ -283,10 +284,10 @@
if await sleep_unless_task_is_done(
self.database_manager.task, # type: ignore[arg-type]
5,
self._trigger_event,
self.database_manager._trigger_event,
): # if true, we are done
return

def trigger(self) -> None:
"""External method to trigger the _manage loop to continue."""
self._trigger_event.set()
self.database_manager.trigger_scheduling_event()

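Note on the scheduling arithmetic: an entry can be marked is_done while its job still sits in the scheduler queue, and the new _done_but_still_running count keeps such entries from shrinking the number of jobs to start too early. A toy calculation with made-up numbers (four learners assumed) shows the difference between the old and new formulas:

# Hypothetical numbers, only to illustrate the counting change.
job_names = {"job-0", "job-1", "job-2", "job-3"}
queued = {"job-0", "job-1"}        # job names currently in the scheduler queue
n_done = 2                         # database entries with is_done=True
n_done_but_running = 1             # done entries whose job still shows up in the queue

available_job_names = job_names - queued
n_done_completely = n_done - n_done_but_running

old_n_to_schedule = max(0, len(available_job_names) - n_done)             # 2 - 2 -> 0
new_n_to_schedule = max(0, len(available_job_names) - n_done_completely)  # 2 - 1 -> 1

assert (old_n_to_schedule, new_n_to_schedule) == (0, 1)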
6 changes: 5 additions & 1 deletion adaptive_scheduler/_server_support/kill_manager.py
@@ -62,7 +62,11 @@ def file_has_error(fname: Path) -> bool:
have_error = []
for entry in database_manager.as_dicts():
fnames = entry["output_logs"]
if entry["job_id"] is not None and any(file_has_error(Path(f)) for f in fnames):
if (
not entry["is_done"]
and entry["job_id"] is not None
and any(file_has_error(Path(f)) for f in fnames)
):
all_fnames = [*fnames, entry["log_fname"]]
have_error.append((entry["job_name"], all_fnames))
return have_error
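Note on the kill-manager guard: the added not entry["is_done"] check means a learner that already finished can no longer be flagged for cancellation because of stale error text in its logs. A small self-contained sketch of the filter, with a stand-in file_has_error that always reports an error, behaves as follows:

from pathlib import Path


def file_has_error(fname: Path) -> bool:
    """Stand-in for the real log scan: pretend every log matches the error condition."""
    return True


entries = [
    {"job_id": "7", "is_done": False, "job_name": "live_job",
     "output_logs": ["live.out"], "log_fname": "live.log"},
    {"job_id": "3", "is_done": True, "job_name": "finished_job",
     "output_logs": ["old.out"], "log_fname": "old.log"},
]

flagged = [
    e["job_name"]
    for e in entries
    if not e["is_done"]
    and e["job_id"] is not None
    and any(file_has_error(Path(f)) for f in e["output_logs"])
]
assert flagged == ["live_job"]  # the finished job is left alone despite its "error" logs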
32 changes: 20 additions & 12 deletions tests/test_database_manager.py
@@ -201,7 +201,7 @@ async def test_database_manager_dispatch_start_stop(
assert db_manager._db is not None
entry = db_manager._db.get(lambda entry: entry.fname == fname)
assert entry is not None
assert entry.job_id is None
assert entry.job_id == "1000"
assert entry.is_done is True


@@ -305,19 +305,27 @@ async def test_database_manager_start_stop(
assert db_manager._db is not None
entry = db_manager._db.get(lambda entry: entry.fname == _ensure_str(fnames[0]))
assert entry is not None
assert entry.job_id is None
assert entry.job_id == job_id

# Start and stop the learner2
index2, _ = db_manager._choose_fname()
db_manager._confirm_submitted(index2, job_name)
fname = await send_message(socket, start_message)
job_id2, job_name2 = "1001", "job_name2"
start_message2 = ("start", job_id2, "log2.log", job_name2)
db_manager._confirm_submitted(index2, job_name2)
fname = await send_message(socket, start_message2)
assert fname == _ensure_str(fnames[1])

# Send a stop message to the DatabaseManager
stop_message = ("stop", fname)
reply = await send_message(socket, stop_message)
assert reply is None

# Check that the database is updated correctly
entry = db_manager._db.get(lambda entry: entry.fname == _ensure_str(fnames[1]))
assert entry is not None
assert entry.job_id == job_id2
assert entry.job_name == job_name2

with pytest.raises(zmq.error.Again, match="Resource temporarily unavailable"):
await send_message(socket, start_message)

@@ -365,18 +373,18 @@ async def test_database_manager_stop_request_and_requests(
db_manager._stop_request(fname1)
entry = db_manager._db.get(lambda entry: entry.fname == fname1)
assert entry is not None
assert entry.job_id is None, (fname1, fname2)
assert entry.job_id == job_id1, (fname1, fname2)
assert entry.is_done is True
assert entry.job_name is None
assert entry.job_name == job_name1

# Stop the job for learner2 using _stop_requests
db_manager._stop_requests([fname2])

entry = db_manager._db.get(lambda entry: entry.fname == fname2)
assert entry is not None
assert entry.job_id is None, (fname1, fname2)
assert entry.job_id == job_id2, (fname1, fname2)
assert entry.is_done is True
assert entry.job_name is None
assert entry.job_name == job_name2


def test_job_failure_after_start_request(db_manager: DatabaseManager) -> None:
@@ -566,9 +574,9 @@ async def test_dependencies(
db_manager._stop_request(fname1)
entry = db_manager._db.get(lambda entry: entry.fname == fname1)
assert entry is not None
assert entry.job_id is None
assert entry.job_id == job_id1
assert entry.is_done is True
assert entry.job_name is None
assert entry.job_name == job_name1

# Try getting a new job
_index2, _fname2 = db_manager._choose_fname()
@@ -589,9 +597,9 @@

entry = db_manager._db.get(lambda entry: entry.fname == fname2)
assert entry is not None
assert entry.job_id is None, (fname1, fname2)
assert entry.job_id == job_id2, (fname1, fname2)
assert entry.is_done is True
assert entry.job_name is None
assert entry.job_name == job_name2
with pytest.raises(
RuntimeError,
match="Requested a new job but no more learners to run in the database.",
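Note on the stop semantics these test changes assert: a stop request now only flips is_done and is_pending and keeps job_id and job_name for bookkeeping, instead of resetting them to None. The toy stand-in below (not the real _DBEntry or SimpleDatabase.update) shows the intended end state:

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Entry:  # toy stand-in for _DBEntry
    fname: str
    job_id: str | None = None
    job_name: str | None = None
    is_done: bool = False
    is_pending: bool = False


entry = Entry(fname="learner1.pickle", job_id="1000", job_name="job_name", is_pending=True)
reset = {"is_done": True, "is_pending": False}  # new reset dict; the old one also cleared job_id/job_name
for key, value in reset.items():
    setattr(entry, key, value)

assert (entry.job_id, entry.job_name, entry.is_done, entry.is_pending) == ("1000", "job_name", True, False)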
4 changes: 4 additions & 0 deletions tests/test_kill_manager.py
@@ -67,6 +67,7 @@ def test_logs_with_string_or_condition_string_error(tmp_path: Path) -> None:
"job_name": "test_job",
"output_logs": [str(logs_file)],
"log_fname": "log_file.log",
"is_done": False,
},
]

@@ -89,6 +90,7 @@ def test_logs_with_string_or_condition_callable_error(tmp_path: Path) -> None:
"job_name": "test_job",
"output_logs": [str(logs_file)],
"log_fname": "log_file.log",
"is_done": False,
},
]

@@ -113,6 +115,7 @@ def test_logs_with_string_or_condition_no_error(tmp_path: Path) -> None:
"job_name": "test_job",
"output_logs": [str(logs_file)],
"log_fname": "log_file.log",
"is_done": False,
},
]

@@ -131,6 +134,7 @@ def test_logs_with_string_or_condition_missing_file() -> None:
"job_name": "test_job",
"output_logs": ["non_existent_file.txt"],
"log_fname": "log_file.log",
"is_done": False,
},
]

4 changes: 2 additions & 2 deletions tests/test_run_manager.py
@@ -313,8 +313,8 @@ async def test_run_manager_auto_restart(
db = rm.database_manager.as_dicts()
assert len(db) == 2
for entry in db:
assert entry["job_id"] is None
assert entry["job_name"] is None
assert entry["job_id"] in ("0", "1", "2")
assert entry["job_name"] in job_names
assert entry["is_done"]
assert entry["log_fname"].endswith(".log")
