Skip to content

Commit 5ac3a3f

Browse files
committed
Automatically trigger loop when possible
1 parent 6f3e6c3 commit 5ac3a3f

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

adaptive_scheduler/_server_support/database_manager.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import json
67
import pickle
78
from dataclasses import asdict, dataclass, field
@@ -194,6 +195,7 @@ def __init__(
194195
self._pickling_time: float | None = None
195196
self._total_learner_size: int | None = None
196197
self._db: SimpleDatabase | None = None
198+
self._trigger_event = asyncio.Event()
197199

198200
def _setup(self) -> None:
199201
if self.db_fname.exists() and not self.overwrite_db:
@@ -214,7 +216,9 @@ def update(self, queue: dict[str, dict[str, str]] | None = None) -> None:
214216
queue = self.scheduler.queue(me_only=True)
215217
job_names_in_queue = [x["job_name"] for x in queue.values()]
216218
failed = self._db.get_all(
217-
lambda e: e.job_name is not None and e.job_name not in job_names_in_queue, # type: ignore[operator]
219+
lambda e: not e.is_done
220+
and e.job_name is not None
221+
and e.job_name not in job_names_in_queue, # type: ignore[operator]
218222
)
219223
self.failed.extend([asdict(entry) for _, entry in failed])
220224
indices = [index for index, _ in failed]
@@ -272,6 +276,12 @@ def _output_logs(self, job_id: str, job_name: str) -> list[Path]:
272276
for f in output_fnames
273277
]
274278

279+
def _done_but_still_running(self) -> list[_DBEntry]:
280+
if self._db is None:
281+
return []
282+
entries = self._db.get_all(lambda e: e.is_done and e.job_id is not None)
283+
return [entry for _, entry in entries]
284+
275285
def _choose_fname(self) -> tuple[int, str | list[str] | None]:
276286
assert self._db is not None
277287
entry = self._db.get(
@@ -338,16 +348,17 @@ def _start_request(
338348

339349
def _stop_request(self, fname: str | list[str] | Path | list[Path]) -> None:
340350
fname_str = _ensure_str(fname)
341-
reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
351+
reset = {"is_done": True, "is_pending": False}
342352
assert self._db is not None
343353
entry_indices = [index for index, _ in self._db.get_all(lambda e: e.fname == fname_str)]
344354
self._db.update(reset, entry_indices)
355+
print(f"Done with {fname_str}")
345356

346357
def _stop_requests(self, fnames: FnamesTypes) -> None:
347358
# Same as `_stop_request` but optimized for processing many `fnames` at once
348359
assert self._db is not None
349360
fnames_str = {str(fname) for fname in _ensure_str(fnames)}
350-
reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
361+
reset = {"is_done": True, "is_pending": False}
351362
entry_indices = [
352363
index for index, _ in self._db.get_all(lambda e: str(e.fname) in fnames_str)
353364
]
@@ -381,7 +392,8 @@ def _dispatch(
381392
if request_type == "stop":
382393
fname = request_arg[0] # workers send us the fname they were given
383394
log.debug("got a stop request", fname=fname)
384-
self._stop_request(fname) # reset the job_id to None
395+
self._stop_request(fname) # set is_done
396+
self.trigger_scheduling_event()
385397
return None
386398
except Exception as e: # noqa: BLE001
387399
return e
@@ -426,3 +438,7 @@ async def _manage(self) -> None:
426438
break
427439
finally:
428440
socket.close()
441+
442+
def trigger_scheduling_event(self) -> None:
443+
"""External method to trigger the _manage loop to continue."""
444+
self._trigger_event.set()

adaptive_scheduler/_server_support/job_manager.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@ def __init__(
174174
# Other attributes
175175
self.n_started = 0
176176
self._request_times: dict[str, str] = {}
177-
self._trigger_event = asyncio.Event()
178177

179178
# Command line launcher options
180179
self.save_dataframe = save_dataframe
@@ -221,7 +220,8 @@ async def _update_database_and_get_not_queued(
221220
n_done = self.database_manager.n_done()
222221
if n_done == len(self.job_names):
223222
return None # we are finished!
224-
n_to_schedule = max(0, len(not_queued) - n_done)
223+
n_done_but_running = len(self.database_manager._done_but_still_running())
224+
n_to_schedule = max(0, n_done_but_running + len(not_queued) - n_done)
225225
return queued, set(list(not_queued)[:n_to_schedule])
226226

227227
async def _start_new_jobs(
@@ -233,6 +233,7 @@ async def _start_new_jobs(
233233
len(not_queued),
234234
self.max_simultaneous_jobs - len(queued),
235235
)
236+
print(f"num_jobs_to_start={num_jobs_to_start}")
236237
for _ in range(num_jobs_to_start):
237238
index, fname = self.database_manager._choose_fname()
238239
if index == -1:
@@ -263,7 +264,7 @@ async def _manage(self) -> None:
263264
if await sleep_unless_task_is_done(
264265
self.database_manager.task, # type: ignore[arg-type]
265266
self.interval,
266-
self._trigger_event,
267+
self.database_manager._trigger_event,
267268
): # if true, we are done
268269
return
269270
except asyncio.CancelledError: # noqa: PERF203
@@ -283,10 +284,10 @@ async def _manage(self) -> None:
283284
if await sleep_unless_task_is_done(
284285
self.database_manager.task, # type: ignore[arg-type]
285286
5,
286-
self._trigger_event,
287+
self.database_manager._trigger_event,
287288
): # if true, we are done
288289
return
289290

290291
def trigger(self) -> None:
291292
"""External method to trigger the _manage loop to continue."""
292-
self._trigger_event.set()
293+
self.database_manager.trigger_scheduling_event()

0 commit comments

Comments (0)