Skip to content

Commit e8d339a — "Automatically trigger loop when possible" (1 parent: 6f3e6c3)

File tree

2 files changed: +28 −9 lines changed

adaptive_scheduler/_server_support/database_manager.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import json
67
import pickle
78
from dataclasses import asdict, dataclass, field
@@ -194,6 +195,7 @@ def __init__(
194195
self._pickling_time: float | None = None
195196
self._total_learner_size: int | None = None
196197
self._db: SimpleDatabase | None = None
198+
self._trigger_event = asyncio.Event()
197199

198200
def _setup(self) -> None:
199201
if self.db_fname.exists() and not self.overwrite_db:
@@ -214,7 +216,9 @@ def update(self, queue: dict[str, dict[str, str]] | None = None) -> None:
214216
queue = self.scheduler.queue(me_only=True)
215217
job_names_in_queue = [x["job_name"] for x in queue.values()]
216218
failed = self._db.get_all(
217-
lambda e: e.job_name is not None and e.job_name not in job_names_in_queue, # type: ignore[operator]
219+
lambda e: not e.is_done
220+
and e.job_name is not None
221+
and e.job_name not in job_names_in_queue, # type: ignore[operator]
218222
)
219223
self.failed.extend([asdict(entry) for _, entry in failed])
220224
indices = [index for index, _ in failed]
@@ -272,6 +276,12 @@ def _output_logs(self, job_id: str, job_name: str) -> list[Path]:
272276
for f in output_fnames
273277
]
274278

279+
def _done_but_still_running(self) -> list[_DBEntry]:
280+
if self._db is None:
281+
return []
282+
entries = self._db.get_all(lambda e: e.is_done and e.job_id is not None)
283+
return [entry for _, entry in entries]
284+
275285
def _choose_fname(self) -> tuple[int, str | list[str] | None]:
276286
assert self._db is not None
277287
entry = self._db.get(
@@ -338,16 +348,17 @@ def _start_request(
338348

339349
def _stop_request(self, fname: str | list[str] | Path | list[Path]) -> None:
340350
fname_str = _ensure_str(fname)
341-
reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
351+
reset = {"is_done": True, "is_pending": False}
342352
assert self._db is not None
343353
entry_indices = [index for index, _ in self._db.get_all(lambda e: e.fname == fname_str)]
344354
self._db.update(reset, entry_indices)
355+
print(f"Done with {fname_str}")
345356

346357
def _stop_requests(self, fnames: FnamesTypes) -> None:
347358
# Same as `_stop_request` but optimized for processing many `fnames` at once
348359
assert self._db is not None
349360
fnames_str = {str(fname) for fname in _ensure_str(fnames)}
350-
reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
361+
reset = {"is_done": True, "is_pending": False}
351362
entry_indices = [
352363
index for index, _ in self._db.get_all(lambda e: str(e.fname) in fnames_str)
353364
]
@@ -381,7 +392,8 @@ def _dispatch(
381392
if request_type == "stop":
382393
fname = request_arg[0] # workers send us the fname they were given
383394
log.debug("got a stop request", fname=fname)
384-
self._stop_request(fname) # reset the job_id to None
395+
self._stop_request(fname) # set is_done
396+
self.trigger()
385397
return None
386398
except Exception as e: # noqa: BLE001
387399
return e
@@ -426,3 +438,7 @@ async def _manage(self) -> None:
426438
break
427439
finally:
428440
socket.close()
441+
442+
def trigger(self) -> None:
443+
"""External method to trigger the _manage loop to continue."""
444+
self._trigger_event.set()

adaptive_scheduler/_server_support/job_manager.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@ def __init__(
174174
# Other attributes
175175
self.n_started = 0
176176
self._request_times: dict[str, str] = {}
177-
self._trigger_event = asyncio.Event()
178177

179178
# Command line launcher options
180179
self.save_dataframe = save_dataframe
@@ -221,7 +220,9 @@ async def _update_database_and_get_not_queued(
221220
n_done = self.database_manager.n_done()
222221
if n_done == len(self.job_names):
223222
return None # we are finished!
224-
n_to_schedule = max(0, len(not_queued) - n_done)
223+
n_done_but_running = len(self.database_manager._done_but_still_running())
224+
n_to_schedule = max(0, n_done_but_running + len(not_queued) - n_done)
225+
print(f"n_done={n_done}, {not_queued=}, {queued=}, {n_to_schedule=}")
225226
return queued, set(list(not_queued)[:n_to_schedule])
226227

227228
async def _start_new_jobs(
@@ -233,6 +234,7 @@ async def _start_new_jobs(
233234
len(not_queued),
234235
self.max_simultaneous_jobs - len(queued),
235236
)
237+
print(f"num_jobs_to_start={num_jobs_to_start}")
236238
for _ in range(num_jobs_to_start):
237239
index, fname = self.database_manager._choose_fname()
238240
if index == -1:
@@ -243,6 +245,7 @@ async def _start_new_jobs(
243245
log.debug(
244246
f"Starting `job_name={job_name}` with `index={index}` and `fname={fname}`",
245247
)
248+
print(f"Starting `job_name={job_name}` with `index={index}` and `fname={fname}`")
246249
await asyncio.to_thread(self.scheduler.start_job, job_name, index=index)
247250
self.database_manager._confirm_submitted(index, job_name)
248251

@@ -263,7 +266,7 @@ async def _manage(self) -> None:
263266
if await sleep_unless_task_is_done(
264267
self.database_manager.task, # type: ignore[arg-type]
265268
self.interval,
266-
self._trigger_event,
269+
self.database_manager._trigger_event,
267270
): # if true, we are done
268271
return
269272
except asyncio.CancelledError: # noqa: PERF203
@@ -283,10 +286,10 @@ async def _manage(self) -> None:
283286
if await sleep_unless_task_is_done(
284287
self.database_manager.task, # type: ignore[arg-type]
285288
5,
286-
self._trigger_event,
289+
self.database_manager._trigger_event,
287290
): # if true, we are done
288291
return
289292

290293
def trigger(self) -> None:
291294
"""External method to trigger the _manage loop to continue."""
292-
self._trigger_event.set()
295+
self.database_manager.trigger()

Commit comments (0)