 import time
 import warnings
 from contextlib import suppress
-from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
+from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Union

 import adaptive
 import dill
@@ -40,7 +40,7 @@ class MaxRestartsReached(Exception):
     your Python code which results jobs being started indefinitely."""


-def _dispatch(request, db_fname):
+def _dispatch(request: Tuple[str, str], db_fname: str):
     request_type, request_arg = request
     log.debug("got a request", request=request)
     try:
@@ -53,7 +53,7 @@ def _dispatch(request, db_fname):
         elif request_type == "stop":
             fname = request_arg  # workers send us the fname they were given
             log.debug("got a stop request", fname=fname)
-            return _done_with_learner(db_fname, fname)  # reset the job_id to None
+            _done_with_learner(db_fname, fname)  # reset the job_id to None
     except Exception as e:
         return e

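The `Tuple[str, str]` annotation on `_dispatch` documents the request wire format: a `(request_type, request_arg)` pair, where the argument appears to be a job id for `"start"` requests (it is handed to `_choose_fname`) and is the learner's filename for `"stop"` requests. A minimal sketch of the two shapes; the concrete job id and filename are placeholders, not taken from this PR:

from typing import Tuple

# Workers ask for work with ("start", job_id) and report completion with ("stop", fname).
start_request: Tuple[str, str] = ("start", "1234567")         # job id is a placeholder
stop_request: Tuple[str, str] = ("stop", "learner_0.pickle")  # filename is a placeholder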
@@ -94,7 +94,7 @@ async def manage_database(url: str, db_fname: str) -> Coroutine:
         )


-def start_database_manager(url: str, db_fname: str):
+def start_database_manager(url: str, db_fname: str) -> asyncio.Task:
     ioloop = asyncio.get_event_loop()
     coro = manage_database(url, db_fname)
     return ioloop.create_task(coro)
@@ -148,17 +148,19 @@ def start_database_manager(url: str, db_fname: str):


 async def manage_jobs(
-    job_names,
-    db_fname,
+    job_names: List[str],
+    db_fname: str,
     ioloop,
     cores=8,
-    job_script_function=make_job_script,
-    run_script="run_learner.py",
-    python_executable=None,
-    interval=30,
+    job_script_function: Callable[
+        [str, int, str, Optional[str]], str
+    ] = make_job_script,
+    run_script: str = "run_learner.py",
+    python_executable: Optional[str] = None,
+    interval: int = 30,
     *,
-    max_simultaneous_jobs=5000,
-    max_fails_per_job=100,
+    max_simultaneous_jobs: int = 5000,
+    max_fails_per_job: int = 100,
 ) -> Coroutine:
     n_started = 0
     max_job_starts = max_fails_per_job * len(job_names)
@@ -230,16 +232,18 @@ async def manage_jobs(


 def start_job_manager(
-    job_names,
-    db_fname,
-    cores=8,
-    job_script_function=make_job_script,
-    run_script="run_learner.py",
-    python_executable=None,
-    interval=30,
+    job_names: List[str],
+    db_fname: str,
+    cores: int = 8,
+    job_script_function: Callable[
+        [str, int, str, Optional[str]], str
+    ] = make_job_script,
+    run_script: str = "run_learner.py",
+    python_executable: Optional[str] = None,
+    interval: int = 30,
     *,
-    max_simultaneous_jobs=5000,
-    max_fails_per_job=40,
+    max_simultaneous_jobs: int = 5000,
+    max_fails_per_job: int = 40,
 ) -> asyncio.Task:
     ioloop = asyncio.get_event_loop()
     coro = manage_jobs(
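The new `Callable[[str, int, str, Optional[str]], str]` annotation spells out the contract for a custom `job_script_function`: called with the job name, the core count, the run-script filename, and an optional Python executable (assuming the same argument order as `_start_job` uses), it must return the job-script text. A hedged sketch of a drop-in with that signature; the scheduler directives are purely illustrative:

from typing import Optional


def my_job_script(
    name: str, cores: int, run_script: str, python_executable: Optional[str] = None
) -> str:
    # Illustrative alternative to make_job_script, matching the annotated signature.
    python = python_executable or "python"
    return "\n".join(
        [
            "#!/bin/bash",
            f"#SBATCH --job-name={name}",
            f"#SBATCH --ntasks={cores}",
            f"srun {python} {run_script}",
        ]
    )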
@@ -275,7 +279,7 @@ def _start_job(name, cores, job_script_function, run_script, python_executable):
         time.sleep(0.5)


-def get_allowed_url():
+def get_allowed_url() -> str:
     """Get an allowed url for the database manager.

     Returns
@@ -289,7 +293,7 @@ def get_allowed_url():
     return f"tcp://{ip}:{port}"


-def create_empty_db(db_fname: str, fnames: List[str]):
+def create_empty_db(db_fname: str, fnames: List[str]) -> None:
     """Create an empty database that keeps track of fname -> (job_id, is_done).

     Parameters
@@ -312,14 +316,14 @@ def get_database(db_fname: str) -> List[Dict[str, Any]]:
         return db.all()


-def _update_db(db_fname: str, running: Dict[str, dict]):
+def _update_db(db_fname: str, running: Dict[str, dict]) -> None:
     """If the job_id isn't running anymore, replace it with None."""
     with TinyDB(db_fname) as db:
         doc_ids = [entry.doc_id for entry in db.all() if entry["job_id"] not in running]
         db.update({"job_id": None}, doc_ids=doc_ids)


-def _choose_fname(db_fname: str, job_id: str):
+def _choose_fname(db_fname: str, job_id: str) -> str:
     Entry = Query()
     with TinyDB(db_fname) as db:
         if db.contains(Entry.job_id == job_id):
@@ -339,13 +343,13 @@ def _choose_fname(db_fname: str, job_id: str):
     return entry["fname"]


-def _done_with_learner(db_fname: str, fname: str):
+def _done_with_learner(db_fname: str, fname: str) -> None:
     Entry = Query()
     with TinyDB(db_fname) as db:
         db.update({"job_id": None, "is_done": True}, Entry.fname == fname)


-def _get_n_jobs_done(db_fname: str):
+def _get_n_jobs_done(db_fname: str) -> int:
     Entry = Query()
     with TinyDB(db_fname) as db:
         return db.count(Entry.is_done == True)  # noqa: E711
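All of these helpers query the TinyDB schema that `create_empty_db` sets up, `fname -> (job_id, is_done)`. A small sketch of that schema and of the kind of query `_get_n_jobs_done` runs; the database and learner filenames are just examples:

from tinydb import TinyDB, Query

with TinyDB("running.json") as db:
    db.insert({"fname": "learner_0.pickle", "job_id": None, "is_done": False})
    Entry = Query()
    n_done = db.count(Entry.is_done == True)  # a TinyDB query expression, not a bool comparison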
@@ -440,14 +444,14 @@ def start_kill_manager(


 def _make_default_run_script(
-    url,
-    learners_file,
-    save_interval,
-    log_interval,
-    goal=None,
-    runner_kwargs=None,
-    run_script_fname="run_learner.py",
-    executor_type="mpi4py",
+    url: str,
+    learners_file: str,
+    save_interval: int,
+    log_interval: int,
+    goal: Optional[Callable[[adaptive.BaseLearner], bool]] = None,
+    runner_kwargs: Optional[Dict[str, Any]] = None,
+    run_script_fname: str = "run_learner.py",
+    executor_type: str = "mpi4py",
 ):
     default_runner_kwargs = dict(shutdown_executor=True)
     runner_kwargs = dict(default_runner_kwargs, goal=goal, **(runner_kwargs or {}))
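`goal` is now typed as `Optional[Callable[[adaptive.BaseLearner], bool]]`: a predicate on the learner that tells the runner when to stop. A minimal example of such a predicate, assuming the learner exposes `npoints` as adaptive's learners do; the threshold is arbitrary:

import adaptive


def goal(learner: adaptive.BaseLearner) -> bool:
    # Stop once the learner holds at least 1000 points.
    return learner.npoints >= 1000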
@@ -668,8 +672,8 @@ def __init__(
         log_file_folder: str = "",
         db_fname: str = "running.json",
         overwrite_db: bool = True,
-        start_job_manager_kwargs: Optional[dict] = None,
-        start_kill_manager_kwargs: Optional[dict] = None,
+        start_job_manager_kwargs: Optional[Dict[str, Any]] = None,
+        start_kill_manager_kwargs: Optional[Dict[str, Any]] = None,
     ):
         # Set from arguments
         self.run_script = run_script
@@ -814,16 +818,16 @@ def _start_kill_manager(self) -> None:
             **self.start_kill_manager_kwargs,
         )

-    def cancel(self):
+    def cancel(self) -> None:
         """Cancel the manager tasks and the jobs in the queue."""
         if self.job_task is not None:
             self.job_task.cancel()
             self.database_task.cancel()
             if self.kill_task is not None:
                 self.kill_task.cancel()
-        return cancel(self.job_names)
+        cancel(self.job_names)

-    def cleanup(self):
+    def cleanup(self) -> None:
         """Cleanup the log and batch files.

         If the `RunManager` is not running, the ``run_script.py`` file
@@ -838,9 +842,9 @@ def cleanup(self):
         running_job_ids = set(queue().keys())
         if self.executor_type == "ipyparallel":
             _delete_old_ipython_profiles(running_job_ids)
-        return cleanup_files(self.job_names, log_file_folder=self.log_file_folder)
+        cleanup_files(self.job_names, log_file_folder=self.log_file_folder)

-    def parse_log_files(self, only_last=True):
+    def parse_log_files(self, only_last: bool = True):
         """Parse the log-files and convert it to a `~pandas.core.frame.DataFrame`.

         Parameters
@@ -859,7 +863,7 @@ def parse_log_files(self, only_last=True):
             self.job_names, only_last, self.db_fname, self.log_file_folder
         )

-    def task_status(self):
+    def task_status(self) -> None:
         r"""Print the stack of the `asyncio.Task`\s."""
         if self.job_task is not None:
             self.job_task.print_stack()
@@ -872,13 +876,13 @@ def get_database(self) -> List[Dict[str, Any]]:
         """Get the database as a list of dicts."""
         return get_database(self.db_fname)

-    def load_learners(self):
+    def load_learners(self) -> None:
         """Load the learners in parallel using `adaptive_scheduler.utils.load_parallel`."""
         from adaptive_scheduler.utils import load_parallel

         load_parallel(self.learners_module.learners, self.learners_module.fnames)

-    def elapsed_time(self):
+    def elapsed_time(self) -> float:
         """Total time elapsed since the RunManager was started."""
         if not self.is_started:
             return 0
@@ -893,7 +897,7 @@ def elapsed_time(self):
             end_time = time.time()
         return end_time - self.start_time

-    def status(self):
+    def status(self) -> str:
         """Return the current status of the RunManager."""
         if not self.is_started:
             return "not yet started"
@@ -912,7 +916,7 @@ def status(self):
             self.end_time = time.time()
         return status

-    def info(self):
+    def info(self) -> None:
         """Display information about the `RunManager`.

         Returns an interactive ipywidget that can be
@@ -958,7 +962,7 @@ def cleanup(_):
             )
         )

-    def _info_html(self):
+    def _info_html(self) -> str:
         jobs = [job for job in queue().values() if job["name"] in self.job_names]
         n_running = sum(job["state"] in ("RUNNING", "R") for job in jobs)
         n_pending = sum(job["state"] in ("PENDING", "Q") for job in jobs)
@@ -1002,5 +1006,5 @@ def _info_html(self):
         </dl>
         """

-    def _repr_html_(self):
+    def _repr_html_(self) -> None:
         return self.info()