@@ -155,6 +155,7 @@ async def manage_jobs(
155
155
run_script = "run_learner.py" ,
156
156
python_executable = None ,
157
157
interval = 30 ,
158
+ remote = None ,
158
159
* ,
159
160
max_simultaneous_jobs = 5000 ,
160
161
max_fails_per_job = 100 ,
@@ -164,7 +165,7 @@ async def manage_jobs(
164
165
with concurrent .futures .ProcessPoolExecutor () as ex :
165
166
while True :
166
167
try :
167
- running = queue ()
168
+ running = queue (remote = remote )
168
169
_update_db (db_fname , running ) # in case some jobs died
169
170
queued = {j ["name" ] for j in running .values () if j ["name" ] in job_names }
170
171
not_queued = set (job_names ) - queued
@@ -236,6 +237,7 @@ def start_job_manager(
236
237
run_script = "run_learner.py" ,
237
238
python_executable = None ,
238
239
interval = 30 ,
240
+ remote = None ,
239
241
* ,
240
242
max_simultaneous_jobs = 5000 ,
241
243
max_fails_per_job = 40 ,
@@ -250,6 +252,7 @@ def start_job_manager(
250
252
run_script ,
251
253
python_executable ,
252
254
interval ,
255
+ remote ,
253
256
max_simultaneous_jobs = max_simultaneous_jobs ,
254
257
max_fails_per_job = max_fails_per_job ,
255
258
)
@@ -356,6 +359,7 @@ async def manage_killer(
356
359
interval : int = 600 ,
357
360
max_cancel_tries : int = 5 ,
358
361
move_to : Optional [str ] = None ,
362
+ remote : Optional [str ] = None ,
359
363
) -> Coroutine :
360
364
# It seems like tasks that print the error message do not always stop working
361
365
# I think it only stops working when the error happens on a node where the logger runs.
@@ -371,7 +375,7 @@ async def manage_killer(
371
375
to_delete = []
372
376
373
377
# cancel/delete only the processes/logs that are running now
374
- for job_id , info in queue ().items ():
378
+ for job_id , info in queue (remote = remote ).items ():
375
379
job_name = info ["name" ]
376
380
if job_id in failed_jobs .get (job_name , []):
377
381
to_cancel .append (job_name )
@@ -427,9 +431,10 @@ def start_kill_manager(
427
431
interval : int = 600 ,
428
432
max_cancel_tries : int = 5 ,
429
433
move_to : Optional [str ] = None ,
434
+ remote : Optional [str ] = None ,
430
435
) -> asyncio .Task :
431
436
ioloop = asyncio .get_event_loop ()
432
- coro = manage_killer (job_names , error , interval , max_cancel_tries , move_to )
437
+ coro = manage_killer (job_names , error , interval , max_cancel_tries , move_to , remote )
433
438
return ioloop .create_task (coro )
434
439
435
440
@@ -667,6 +672,7 @@ def __init__(
667
672
overwrite_db : bool = True ,
668
673
start_job_manager_kwargs : Optional [dict ] = None ,
669
674
start_kill_manager_kwargs : Optional [dict ] = None ,
675
+ remote : Optional [str ] = None ,
670
676
):
671
677
# Set from arguments
672
678
self .run_script = run_script
@@ -688,6 +694,7 @@ def __init__(
688
694
self .overwrite_db = overwrite_db
689
695
self .start_job_manager_kwargs = start_job_manager_kwargs or {}
690
696
self .start_kill_manager_kwargs = start_kill_manager_kwargs or {}
697
+ self .remote = remote
691
698
692
699
# Set in methods
693
700
self .job_task = None
@@ -794,6 +801,7 @@ def _start_job_manager(self) -> None:
794
801
interval = self .job_manager_interval ,
795
802
run_script = self .run_script ,
796
803
job_script_function = self .job_script_function ,
804
+ remote = self .remote ,
797
805
** self .start_job_manager_kwargs ,
798
806
)
799
807
@@ -808,6 +816,7 @@ def _start_kill_manager(self) -> None:
808
816
error = self .kill_on_error ,
809
817
interval = self .kill_interval ,
810
818
move_to = self .move_old_logs_to ,
819
+ remote = self .remote ,
811
820
** self .start_kill_manager_kwargs ,
812
821
)
813
822
@@ -818,7 +827,7 @@ def cancel(self):
818
827
self .database_task .cancel ()
819
828
if self .kill_task is not None :
820
829
self .kill_task .cancel ()
821
- return cancel (self .job_names )
830
+ return cancel (self .job_names , remote = self . remote )
822
831
823
832
def cleanup (self ):
824
833
"""Cleanup the log and batch files.
@@ -953,7 +962,11 @@ def cleanup(_):
953
962
)
954
963
955
964
def _info_html (self ):
956
- jobs = [job for job in queue ().values () if job ["name" ] in self .job_names ]
965
+ jobs = [
966
+ job
967
+ for job in queue (remote = self .remote ).values ()
968
+ if job ["name" ] in self .job_names
969
+ ]
957
970
n_running = sum (job ["state" ] in ("RUNNING" , "R" ) for job in jobs )
958
971
n_pending = sum (job ["state" ] in ("PENDING" , "Q" ) for job in jobs )
959
972
n_done = sum (job ["is_done" ] for job in self .get_database ())
0 commit comments