@@ -46,8 +46,17 @@ def __validate_file_ids(self):
4646 'Some file IDs do not belong to the project. Please provide only files from the same project.' )
4747
4848 def get_image_data (self , diffgram_file ):
49+ MAX_RETRIES = 10
4950 if hasattr (diffgram_file , 'image' ):
50- image = imread (diffgram_file .image .get ('url_signed' ))
51+ for i in range (0 , MAX_RETRIES ):
52+ try :
53+ image = imread (diffgram_file .image .get ('url_signed' ))
54+ break
55+ except Exception as e :
56+ if i < MAX_RETRIES :
57+ continue
58+ else :
59+ raise e
5160 return image
5261 else :
5362 raise Exception ('Pytorch datasets only support images. Please provide only file_ids from images' )
@@ -70,11 +79,18 @@ def get_file_instances(self, diffgram_file):
7079 sample ['x_max_list' ] = x_max_list
7180 sample ['y_min_list' ] = y_min_list
7281 sample ['y_max_list' ] = y_max_list
82+ else :
83+ sample ['x_min_list' ] = []
84+ sample ['x_max_list' ] = []
85+ sample ['y_min_list' ] = []
86+ sample ['y_max_list' ] = []
7387
7488 if 'polygon' in instance_types_in_file :
7589 has_poly = True
7690 mask_list = self .extract_masks_from_polygon (instance_list , diffgram_file )
7791 sample ['polygon_mask_list' ] = mask_list
92+ else :
93+ sample ['polygon_mask_list' ] = []
7894
7995 if len (instance_types_in_file ) > 2 and has_boxes and has_boxes :
8096 raise NotImplementedError (
@@ -83,6 +99,7 @@ def get_file_instances(self, diffgram_file):
8399
84100 label_id_list , label_name_list = self .extract_labels (instance_list )
85101 sample ['label_id_list' ] = label_id_list
102+ sample ['instance_types_in_file' ] = instance_types_in_file
86103 sample ['label_name_list' ] = label_name_list
87104
88105 return sample
0 commit comments