@@ -75,6 +75,14 @@ def main():
75
75
type = int ,
76
76
default = 10
77
77
)
78
+ p_process_datasets .add_argument (
79
+ '-r' , '--random_seeds' , dest = 'RANDOM_SEEDS' ,
80
+ type = _random_seed_list ,
81
+ default = None ,
82
+ help = "Defines a list of random seeds. Must be comma separated "
83
+ "integers. Must be same length as <NUM_SPLITS>. If omitted will "
84
+ "default to randomized seeds."
85
+ )
78
86
79
87
p_all = command_parsers .add_parser (
80
88
"all" ,
@@ -116,6 +124,12 @@ def full_workflow(args):
116
124
117
125
def process_datasets (args ):
118
126
127
+ if args .RANDOM_SEEDS is not None and len (args .RANDOM_SEEDS ) != args .NUM_SPLITS :
128
+ sys .exit (
129
+ "<RANDOM_SEEDS> must contain same number of random seed values as "
130
+ "<NUM_SPLITS>."
131
+ )
132
+
119
133
120
134
local_path = args .WORKDIR .joinpath ('data_in_tmp' )
121
135
@@ -481,7 +495,10 @@ def split_data_sets(
481
495
split_type = args .SPLIT_TYPE
482
496
ratio = (8 ,1 ,1 )
483
497
stratify_by = None
484
- random_state = None
498
+ if args .RANDOM_SEEDS is not None :
499
+ random_seeds = args .RANDOM_SEEDS
500
+ else :
501
+ random_seeds = [None ] * args .NUM_SPLITS
485
502
486
503
for data_set in data_sets_info .keys ():
487
504
if data_sets [data_set ].experiments is not None :
@@ -525,7 +542,7 @@ def split_data_sets(
525
542
split_type = split_type ,
526
543
ratio = ratio ,
527
544
stratify_by = stratify_by ,
528
- random_state = random_state
545
+ random_state = random_seeds [ i ]
529
546
)
530
547
train_keys = (
531
548
splits [i ]
@@ -768,6 +785,17 @@ def _check_folder(path: Union[str, PathLike, Path]) -> Path:
768
785
769
786
return abs_path
770
787
788
+ def _random_seed_list (list : str ) -> list :
789
+
790
+ if not isinstance (list , str ):
791
+ raise TypeError (
792
+ f"'random_seed' must be of type str. Supplied argument is of type "
793
+ f"{ type (list )} ."
794
+ )
795
+ list_ = list .split (',' )
796
+ return [int (item ) for item in list_ ]
797
+
798
+
771
799
if __name__ == '__main__':
    # Script entry point. A Ctrl-C from the user ends the run quietly
    # instead of printing a traceback.
    try:
        main()
    except KeyboardInterrupt:
        pass
0 commit comments