14
14
@click .argument ('inference_graph_uri' )
15
15
@click .argument ('label_map_uri' )
16
16
@click .argument ('projects_uri' )
17
+ @click .argument ('output_dir_uri' )
17
18
@click .option ('--mask-uri' , default = None ,
18
19
help = 'URI for mask GeoJSON file to use as filter for detections' )
19
20
@click .option ('--channel-order' , nargs = 3 , type = int ,
20
21
default = default_channel_order , help = 'Index of RGB channels' )
21
22
@click .option ('--chip-size' , default = 300 )
22
23
@click .option ('--score-thresh' , default = 0.5 ,
23
24
help = 'Score threshold of predictions to keep' )
24
- @click .option ('--merge-thresh' , default = 0.05 ,
25
+ @click .option ('--merge-thresh' , default = 0.5 ,
25
26
help = 'IOU threshold for merging predictions' )
26
27
@click .option ('--save-temp' , is_flag = True )
27
28
def predict_array (inference_graph_uri , label_map_uri , projects_uri ,
28
- mask_uri , channel_order , chip_size , score_thresh ,
29
- merge_thresh , save_temp ):
30
- job_index = int (os .environ ['AWS_BATCH_JOB_ARRAY_INDEX' ])
29
+ output_dir_uri , mask_uri , channel_order , chip_size ,
30
+ score_thresh , merge_thresh , save_temp ):
31
+ job_ind = int (os .environ ['AWS_BATCH_JOB_ARRAY_INDEX' ])
31
32
32
33
prefix = temp_root_dir
33
34
temp_dir = os .path .join (prefix , 'predict-array' ) if save_temp else None
34
35
with MyTemporaryDirectory (temp_dir , prefix ) as temp_dir :
35
36
projects_path = download_if_needed (projects_uri , temp_dir )
36
37
with open (projects_path , 'r' ) as projects_file :
37
38
projects = json .load (projects_file )
38
- if job_index >= len (projects ):
39
- raise ValueError (
40
- 'There are {} projects and job_index is {}!' .format (
41
- len (projects ), job_index ))
42
- project = projects [job_index ]
43
- image_uris = project ['images' ]
44
- output_uri = project ['annotations' ]
39
+
40
+ def get_image_coords ():
41
+ image_ind = 0
42
+ for project_ind , project in enumerate (projects ):
43
+ for project_image_ind , image_uri in enumerate (project ['images' ]):
44
+ if job_ind == image_ind :
45
+ return project_ind , project_image_ind
46
+ image_ind += 1
47
+
48
+ # Run predict for single image and generate ouput_uri based on
49
+ # the index of the image within the project.
50
+ project_ind , project_image_ind = get_image_coords ()
51
+ project = projects [project_ind ]
52
+ image_uri = project ['images' ][project_image_ind ]
53
+ image_uris = [image_uri ]
54
+ output_uri = os .path .join (
55
+ output_dir_uri , project ['id' ],
56
+ '{}.json' .format (project_image_ind ))
45
57
output_debug_uri = None
46
58
47
59
_predict (inference_graph_uri , label_map_uri , image_uris ,
@@ -50,6 +62,5 @@ def predict_array(inference_graph_uri, label_map_uri, projects_uri,
50
62
save_temp )
51
63
52
64
53
-
54
65
if __name__ == '__main__' :
55
66
predict_array ()
0 commit comments