Skip to content

Commit e5d323e

Browse files
committed
added ability to define list of random_seeds on the commandline for reproducability of splits
1 parent f3a5218 commit e5d323e

File tree

1 file changed

+30
-2
lines changed

1 file changed

+30
-2
lines changed

scripts/prepare_data_for_improve.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,14 @@ def main():
7575
type=int,
7676
default=10
7777
)
78+
p_process_datasets.add_argument(
79+
'-r', '--random_seeds', dest='RANDOM_SEEDS',
80+
type=_random_seed_list,
81+
default=None,
82+
help="Defines a list of random seeds. Must be comma separated "
83+
"integers. Must be same length as <NUM_SPLITS>. If omitted will "
84+
"default to randomized seeds."
85+
)
7886

7987
p_all = command_parsers.add_parser(
8088
"all",
@@ -116,6 +124,12 @@ def full_workflow(args):
116124

117125
def process_datasets(args):
118126

127+
if args.RANDOM_SEEDS is not None and len(args.RANDOM_SEEDS) != args.NUM_SPLITS:
128+
sys.exit(
129+
"<RANDOM_SEEDS> must contain same number of random seed values as "
130+
"<NUM_SPLITS>."
131+
)
132+
119133

120134
local_path = args.WORKDIR.joinpath('data_in_tmp')
121135

@@ -481,7 +495,10 @@ def split_data_sets(
481495
split_type = args.SPLIT_TYPE
482496
ratio = (8,1,1)
483497
stratify_by = None
484-
random_state = None
498+
if args.RANDOM_SEEDS is not None:
499+
random_seeds = args.RANDOM_SEEDS
500+
else:
501+
random_seeds = [None] * args.NUM_SPLITS
485502

486503
for data_set in data_sets_info.keys():
487504
if data_sets[data_set].experiments is not None:
@@ -525,7 +542,7 @@ def split_data_sets(
525542
split_type=split_type,
526543
ratio=ratio,
527544
stratify_by=stratify_by,
528-
random_state=random_state
545+
random_state=random_seeds[i]
529546
)
530547
train_keys = (
531548
splits[i]
@@ -768,6 +785,17 @@ def _check_folder(path: Union[str, PathLike, Path]) -> Path:
768785

769786
return abs_path
770787

788+
def _random_seed_list(list: str) -> list:
789+
790+
if not isinstance(list, str):
791+
raise TypeError(
792+
f"'random_seed' must be of type str. Supplied argument is of type "
793+
f"{type(list)}."
794+
)
795+
list_ = list.split(',')
796+
return [int(item) for item in list_]
797+
798+
771799
if __name__ == '__main__':
772800
try: main()
773801
except KeyboardInterrupt: pass

0 commit comments

Comments
 (0)