@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import json
 import pickle
 from dataclasses import asdict, dataclass, field
@@ -194,6 +195,7 @@ def __init__(
         self._pickling_time: float | None = None
         self._total_learner_size: int | None = None
         self._db: SimpleDatabase | None = None
+        self._trigger_event = asyncio.Event()
 
     def _setup(self) -> None:
         if self.db_fname.exists() and not self.overwrite_db:
@@ -214,7 +216,9 @@ def update(self, queue: dict[str, dict[str, str]] | None = None) -> None:
             queue = self.scheduler.queue(me_only=True)
         job_names_in_queue = [x["job_name"] for x in queue.values()]
         failed = self._db.get_all(
-            lambda e: e.job_name is not None and e.job_name not in job_names_in_queue,  # type: ignore[operator]
+            lambda e: not e.is_done
+            and e.job_name is not None
+            and e.job_name not in job_names_in_queue,  # type: ignore[operator]
         )
         self.failed.extend([asdict(entry) for _, entry in failed])
         indices = [index for index, _ in failed]
@@ -272,6 +276,12 @@ def _output_logs(self, job_id: str, job_name: str) -> list[Path]:
             for f in output_fnames
         ]
 
+    def _done_but_still_running(self) -> list[_DBEntry]:
+        if self._db is None:
+            return []
+        entries = self._db.get_all(lambda e: e.is_done and e.job_id is not None)
+        return [entry for _, entry in entries]
+
     def _choose_fname(self) -> tuple[int, str | list[str] | None]:
         assert self._db is not None
         entry = self._db.get(
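Note, not part of the diff: the new `_done_but_still_running` helper selects entries whose learner is marked done but whose job record has not been cleaned up yet (possible now that `_stop_request` keeps `job_id` and `job_name`, see the next hunk). A minimal self-contained sketch of the same predicate, using a hypothetical `_FakeEntry` stand-in for `_DBEntry`:

from __future__ import annotations

from dataclasses import dataclass

@dataclass
class _FakeEntry:
    fname: str
    is_done: bool
    job_id: str | None

entries = [
    _FakeEntry("a.pickle", is_done=True, job_id="123"),   # done, job still registered
    _FakeEntry("b.pickle", is_done=True, job_id=None),    # done and already cleaned up
    _FakeEntry("c.pickle", is_done=False, job_id="456"),  # not done yet
]
still_running = [e for e in entries if e.is_done and e.job_id is not None]
assert [e.fname for e in still_running] == ["a.pickle"]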
@@ -338,16 +348,17 @@ def _start_request(
 
     def _stop_request(self, fname: str | list[str] | Path | list[Path]) -> None:
         fname_str = _ensure_str(fname)
-        reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
+        reset = {"is_done": True, "is_pending": False}
         assert self._db is not None
         entry_indices = [index for index, _ in self._db.get_all(lambda e: e.fname == fname_str)]
         self._db.update(reset, entry_indices)
+        print(f"Done with {fname_str}")
 
     def _stop_requests(self, fnames: FnamesTypes) -> None:
         # Same as `_stop_request` but optimized for processing many `fnames` at once
         assert self._db is not None
         fnames_str = {str(fname) for fname in _ensure_str(fnames)}
-        reset = {"job_id": None, "is_done": True, "job_name": None, "is_pending": False}
+        reset = {"is_done": True, "is_pending": False}
         entry_indices = [
             index for index, _ in self._db.get_all(lambda e: str(e.fname) in fnames_str)
         ]
@@ -381,7 +392,8 @@ def _dispatch(
             if request_type == "stop":
                 fname = request_arg[0]  # workers send us the fname they were given
                 log.debug("got a stop request", fname=fname)
-                self._stop_request(fname)  # reset the job_id to None
+                self._stop_request(fname)  # set is_done
+                self.trigger()
             return None
         except Exception as e:  # noqa: BLE001
             return e
@@ -426,3 +438,7 @@ async def _manage(self) -> None:
                 break
         finally:
             socket.close()
+
+    def trigger(self) -> None:
+        """External method to trigger the _manage loop to continue."""
+        self._trigger_event.set()
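The diff shows only the producer side of the new asyncio.Event: the stop-request handler calls self.trigger(), which sets self._trigger_event. A sketch of how the consumer side could look in the _manage loop, assuming that loop currently sleeps for a fixed interval between iterations (the interval argument and the helper name are hypothetical, not part of this commit):

import asyncio

async def _wait_for_trigger(event: asyncio.Event, interval: float) -> None:
    # Return as soon as trigger() sets the event, or after `interval` seconds.
    try:
        await asyncio.wait_for(event.wait(), timeout=interval)
    except asyncio.TimeoutError:
        pass
    finally:
        # Clear the event so the next iteration blocks again until the next trigger() call.
        event.clear()

With such a wait in place, a worker's stop request would wake the manager immediately instead of waiting out the rest of the sleep interval.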