@@ -892,10 +892,158 @@ def close(self):
892
892
pass
893
893
894
894
895
+ class PsijWorker (Worker ):
896
+ """A worker to execute tasks using PSI/J."""
897
+
898
+ def __init__ (self , subtype , ** kwargs ):
899
+ """
900
+ Initialize PsijWorker.
901
+
902
+ Parameters
903
+ ----------
904
+ subtype : str
905
+ Scheduler for PSI/J.
906
+ """
907
+ try :
908
+ import psij
909
+ except ImportError :
910
+ logger .critical ("Please install psij." )
911
+ raise
912
+ logger .debug ("Initialize PsijWorker" )
913
+ self .psij = psij
914
+
915
+ # Check if the provided subtype is valid
916
+ valid_subtypes = ["local" , "slurm" ]
917
+ if subtype not in valid_subtypes :
918
+ raise ValueError (
919
+ f"Invalid 'subtype' provided. Available options: { ', ' .join (valid_subtypes )} "
920
+ )
921
+
922
+ self .subtype = subtype
923
+
924
+ def run_el (self , interface , rerun = False , ** kwargs ):
925
+ """Run a task."""
926
+ return self .exec_psij (interface , rerun = rerun )
927
+
928
+ def make_spec (self , cmd = None , arg = None ):
929
+ """
930
+ Create a PSI/J job specification.
931
+
932
+ Parameters
933
+ ----------
934
+ cmd : str, optional
935
+ Executable command. Defaults to None.
936
+ arg : list, optional
937
+ List of arguments. Defaults to None.
938
+
939
+ Returns
940
+ -------
941
+ psij.JobSpec
942
+ PSI/J job specification.
943
+ """
944
+ spec = self .psij .JobSpec ()
945
+ spec .executable = cmd
946
+ spec .arguments = arg
947
+
948
+ return spec
949
+
950
+ def make_job (self , spec , attributes ):
951
+ """
952
+ Create a PSI/J job.
953
+
954
+ Parameters
955
+ ----------
956
+ spec : psij.JobSpec
957
+ PSI/J job specification.
958
+ attributes : any
959
+ Job attributes.
960
+
961
+ Returns
962
+ -------
963
+ psij.Job
964
+ PSI/J job.
965
+ """
966
+ job = self .psij .Job ()
967
+ job .spec = spec
968
+ return job
969
+
970
+ async def exec_psij (self , runnable , rerun = False ):
971
+ """
972
+ Run a task (coroutine wrapper).
973
+
974
+ Raises
975
+ ------
976
+ Exception
977
+ If stderr is not empty.
978
+
979
+ Returns
980
+ -------
981
+ None
982
+ """
983
+ import pickle
984
+ from pathlib import Path
985
+
986
+ jex = self .psij .JobExecutor .get_instance (self .subtype )
987
+ absolute_path = Path (__file__ ).parent
988
+
989
+ if isinstance (runnable , TaskBase ):
990
+ cache_dir = runnable .cache_dir
991
+ file_path = cache_dir / "runnable_function.pkl"
992
+ with open (file_path , "wb" ) as file :
993
+ pickle .dump (runnable ._run , file )
994
+ func_path = absolute_path / "run_pickled.py"
995
+ spec = self .make_spec ("python" , [func_path , file_path ])
996
+ else : # it could be tuple that includes pickle files with tasks and inputs
997
+ cache_dir = runnable [- 1 ].cache_dir
998
+ file_path_1 = cache_dir / "taskmain.pkl"
999
+ file_path_2 = cache_dir / "ind.pkl"
1000
+ ind , task_main_pkl , task_orig = runnable
1001
+ with open (file_path_1 , "wb" ) as file :
1002
+ pickle .dump (task_main_pkl , file )
1003
+ with open (file_path_2 , "wb" ) as file :
1004
+ pickle .dump (ind , file )
1005
+ func_path = absolute_path / "run_pickled.py"
1006
+ spec = self .make_spec (
1007
+ "python" ,
1008
+ [
1009
+ func_path ,
1010
+ file_path_1 ,
1011
+ file_path_2 ,
1012
+ ],
1013
+ )
1014
+
1015
+ if rerun :
1016
+ spec .arguments .append ("--rerun" )
1017
+
1018
+ spec .stdout_path = cache_dir / "demo.stdout"
1019
+ spec .stderr_path = cache_dir / "demo.stderr"
1020
+
1021
+ job = self .make_job (spec , None )
1022
+ jex .submit (job )
1023
+ job .wait ()
1024
+
1025
+ if spec .stderr_path .stat ().st_size > 0 :
1026
+ with open (spec .stderr_path , "r" ) as stderr_file :
1027
+ stderr_contents = stderr_file .read ()
1028
+ raise Exception (
1029
+ f"stderr_path '{ spec .stderr_path } ' is not empty. Contents:\n { stderr_contents } "
1030
+ )
1031
+
1032
+ return
1033
+
1034
+ def close (self ):
1035
+ """Finalize the internal pool of tasks."""
1036
+ pass
1037
+
1038
+
895
1039
WORKERS = {
896
1040
"serial" : SerialWorker ,
897
1041
"cf" : ConcurrentFuturesWorker ,
898
1042
"slurm" : SlurmWorker ,
899
1043
"dask" : DaskWorker ,
900
1044
"sge" : SGEWorker ,
1045
+ ** {
1046
+ "psij-" + subtype : lambda subtype = subtype : PsijWorker (subtype = subtype )
1047
+ for subtype in ["local" , "slurm" ]
1048
+ },
901
1049
}
0 commit comments