@@ -5,7 +5,16 @@
 from worker.agent_storage import AgentStorage
 from worker.fs_storages import EmptyStorage
 from worker import constants
-from collections import defaultdict
+
+
+def _maybe_append_image_extension(name, ext):
+    name_split = os.path.splitext(name)
+    if name_split[1] == '':
+        normalized_ext = ('.' + ext).replace('..', '.')
+        sly.image.validate_ext(normalized_ext)
+        return name + normalized_ext
+    else:
+        return name
 
 
 class DataManager(object):
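The new helper appends `ext` only when `name` carries no extension of its own; `sly.image.validate_ext` (from supervisely_lib, called in the hunk above) raises on unsupported formats. A standalone sketch of the behavior, with that validation stubbed out so it runs without the library:

    import os

    # Mirrors the helper added above; sly.image.validate_ext is omitted here
    # so the sketch runs without supervisely_lib.
    def _maybe_append_image_extension(name, ext):
        if os.path.splitext(name)[1] == '':
            # 'ext' may arrive with or without a leading dot; normalize to one.
            return name + ('.' + ext).replace('..', '.')
        return name

    assert _maybe_append_image_extension('img_0001', 'png') == 'img_0001.png'
    assert _maybe_append_image_extension('img_0001', '.png') == 'img_0001.png'
    assert _maybe_append_image_extension('img_0001.jpg', 'png') == 'img_0001.jpg'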
@@ -34,10 +43,10 @@ def download_nn(self, name, parent_dir):
             self.logger.info('NN has been copied from local storage.')
             return
 
-        model_in_mb = int(float(model_info.size) / 1024 / 1024)
-        progress = sly.Progress('Download NN: {!r}'.format(name), model_in_mb)
+        model_in_mb = int(float(model_info.size) / 1024 / 1024 + 1)
+        progress = sly.Progress('Download NN: {!r}'.format(name), model_in_mb, self.logger)
 
-        self.public_api.model.download_to_dir(self.workspace_id, name, parent_dir, progress.iter_done_report)
+        self.public_api.model.download_to_dir(self.workspace_id, name, parent_dir, progress.iters_done_report)
         self.logger.info('NN has been downloaded from server.')
 
         if self.has_nn_storage():
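Two small fixes ride along in this hunk: the progress callback becomes `iters_done_report` (matching the batch-style callbacks used elsewhere in this diff), and the size estimate gains a `+ 1`, presumably so models under one megabyte do not produce a zero-length progress total. The arithmetic, checked in isolation:

    # Hypothetical 500 KiB model: truncation alone reports 0 MB of work.
    size_bytes = 500 * 1024
    print(int(float(size_bytes) / 1024 / 1024))      # 0  (old behavior)
    print(int(float(size_bytes) / 1024 / 1024 + 1))  # 1  (new behavior)
    # math.ceil would be the exact round-up; the +1 variant overshoots exact
    # multiples by one, which is harmless for a progress-bar total.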
@@ -87,34 +96,37 @@ def download_dataset(self, dataset, dataset_id):
                                                      'images_to_download': len(images_to_download)})
         if len(images_to_download) + len(images_in_cache) != len(images):
             raise RuntimeError("Error with images cache during download. Please contact support.")
-        for batch_cache in sly.batched(list(zip(images_in_cache, images_cache_paths)), constants.BATCH_SIZE_GET_IMAGES_INFO()):
-            img_cache_ids = [img_info.id for img_info, _ in batch_cache]
+
+        if len(images_in_cache) > 0:
+            img_cache_ids = [img_info.id for img_info in images_in_cache]
             ann_info_list = self.public_api.annotation.download_batch(dataset_id, img_cache_ids, progress_anns.iters_done_report)
-            img_name_to_ann = {ann.image_name: ann.annotation for ann in ann_info_list}
-            for img_info, img_cache_path in batch_cache:
-                dataset.add_item_file(img_info.name, img_cache_path, img_name_to_ann[img_info.name])
+            img_name_to_ann = {ann.image_id: ann.annotation for ann in ann_info_list}
+            for img_info, img_cache_path in zip(images_in_cache, images_cache_paths):
+                item_name = _maybe_append_image_extension(img_info.name, img_info.ext)
+                dataset.add_item_file(item_name, img_cache_path, img_name_to_ann[img_info.id])
                 progress_imgs.iter_done_report()
 
         # download images from server
-        for batch_download in sly.batched(images_to_download, constants.BATCH_SIZE_GET_IMAGES_INFO()):
+        if len(images_to_download) > 0:
             #prepare lists for api methods
             img_ids = []
             img_paths = []
-            for img_info in batch_download:
+            for img_info in images_to_download:
                 img_ids.append(img_info.id)
                 # TODO download to a temp file and use dataset api to add the image to the dataset.
-                img_paths.append(dataset.deprecated_make_img_path(img_info.name, img_info.ext))
+                img_paths.append(
+                    os.path.join(dataset.img_dir, _maybe_append_image_extension(img_info.name, img_info.ext)))
 
             # download annotations
             ann_info_list = self.public_api.annotation.download_batch(dataset_id, img_ids, progress_anns.iters_done_report)
-            img_name_to_ann = {ann.image_name: ann.annotation for ann in ann_info_list}
-            self.public_api.image.download_batch(dataset_id, img_ids, img_paths, progress_imgs.iters_done_report)
-            for img_info, img_path in zip(batch_download, img_paths):
-                dataset.add_item_file(img_info.name, img_path, img_name_to_ann[img_info.name])
+            img_name_to_ann = {ann.image_id: ann.annotation for ann in ann_info_list}
+            self.public_api.image.download_paths(dataset_id, img_ids, img_paths, progress_imgs.iters_done_report)
+            for img_info, img_path in zip(images_to_download, img_paths):
+                dataset.add_item_file(img_info.name, img_path, img_name_to_ann[img_info.id])
 
             if self.has_images_storage():
                 progress_cache = sly.Progress('Dataset {!r}: cache images'.format(dataset.name), len(img_paths), self.logger)
-                img_hashes = [img_info.hash for img_info in batch_download]
+                img_hashes = [img_info.hash for img_info in images_to_download]
                 self.storage.images.write_objects(img_paths, img_hashes, progress_cache.iter_done_report)
 
         # @TODO: remove legacy stuff
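Keying the annotation map by `ann.image_id` (and looking up with `img_info.id`) makes the image-to-annotation join robust now that item names can be rewritten by `_maybe_append_image_extension`; note the dict is still named `img_name_to_ann`, a leftover worth renaming in a follow-up. A toy illustration with hypothetical stand-ins for the API records:

    from collections import namedtuple

    # Hypothetical, simplified stand-ins for the API's info objects.
    ImgInfo = namedtuple('ImgInfo', ['id', 'name', 'ext'])
    AnnInfo = namedtuple('AnnInfo', ['image_id', 'image_name', 'annotation'])

    img_infos = [ImgInfo(id=17, name='scene_01', ext='png')]   # no extension in name
    ann_infos = [AnnInfo(image_id=17, image_name='scene_01', annotation={'objects': []})]

    # An id-keyed join survives the rename to 'scene_01.png'; a name-keyed
    # dict would only work while both sides keep the exact same name string.
    ann_by_id = {ann.image_id: ann.annotation for ann in ann_infos}
    assert ann_by_id[img_infos[0].id] == {'objects': []}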
@@ -146,44 +158,31 @@ def upload_project(self, parent_dir, project_name, new_title, legacy=False, add_
         self.logger.info('PROJECT_CREATED',extra={'event_type': sly.EventType.PROJECT_CREATED, 'project_id': project_id})
 
     def upload_dataset(self, dataset, dataset_id):
-        progress = None
+        progress_cache = None
         items_count = len(dataset)
-        hash_to_img_paths = defaultdict(list)
-        hash_to_ann_paths = defaultdict(list)
-        hash_to_item_names = defaultdict(list)
+
+        item_names = []
+        img_paths = []
+        ann_paths = []
         for item_name in dataset:
+            item_names.append(item_name)
             item_paths = dataset.get_item_paths(item_name)
-            img_hash = sly.fs.get_file_hash(item_paths.img_path)
-            hash_to_img_paths[img_hash].append(item_paths.img_path)
-            hash_to_ann_paths[img_hash].append(item_paths.ann_path)
-            hash_to_item_names[img_hash].append(item_name)
+            img_paths.append(item_paths.img_path)
+            ann_paths.append(item_paths.ann_path)
+
             if self.has_images_storage():
-                if progress is None:
-                    progress = sly.Progress('Dataset {!r}: cache images'.format(dataset.name), items_count, self.logger)
+                if progress_cache is None:
+                    progress_cache = sly.Progress('Dataset {!r}: cache images'.format(dataset.name), items_count, self.logger)
+
+                img_hash = sly.fs.get_file_hash(item_paths.img_path)
                 self.storage.images.write_object(item_paths.img_path, img_hash)
-                progress.iter_done_report()
+                progress_cache.iter_done_report()
+
+        progress = sly.Progress('Dataset {!r}: upload images'.format(dataset.name), items_count, self.logger)
+        image_infos = self.public_api.image.upload_paths(dataset_id, item_names, img_paths, progress.iters_done_report)
 
-        progress_img = sly.Progress('Dataset {!r}: upload images'.format(dataset.name), items_count, self.logger)
-        progress_ann = sly.Progress('Dataset {!r}: upload annotations'.format(dataset.name), items_count, self.logger)
-
-        def add_images_annotations(hashes, pb_img_cb, pb_ann_cb):
-            names = [name for hash in hashes for name in hash_to_item_names[hash]]
-            unrolled_hashes = [hash for hash in hashes for _ in range(len(hash_to_item_names[hash]))]
-            ann_paths = [path for hash in hashes for path in hash_to_ann_paths[hash]]
-            remote_infos = self.public_api.image.add_batch(dataset_id, names, unrolled_hashes, pb_img_cb)
-            self.public_api.annotation.upload_batch_paths(dataset_id, [info.id for info in remote_infos], ann_paths, pb_ann_cb)
-
-        # add already uploaded images + attach annotations
-        remote_hashes = self.public_api.image.check_existing_hashes(list(hash_to_img_paths.keys()))
-        if len(remote_hashes) > 0:
-            add_images_annotations(remote_hashes, progress_img.iters_done_report, progress_ann.iters_done_report)
-
-        # upload new images + add annotations
-        new_hashes = list(set(hash_to_img_paths.keys()) - set(remote_hashes))
-        img_paths = [path for hash in new_hashes for path in hash_to_img_paths[hash]]
-        self.public_api.image.upload_batch_paths(img_paths, progress_img.iters_done_report)
-        if len(new_hashes) > 0:
-            add_images_annotations(new_hashes, None, progress_ann.iters_done_report)
+        progress = sly.Progress('Dataset {!r}: upload annotations'.format(dataset.name), items_count, self.logger)
+        self.public_api.annotation.upload_paths([info.id for info in image_infos], ann_paths, progress.iters_done_report)
 
     def upload_archive(self, task_id, dir_to_archive, archive_name):
         self.logger.info("PACK_TO_ARCHIVE ...")
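The rewritten `upload_dataset` drops the client-side hash bookkeeping (`check_existing_hashes`, `add_batch`, `upload_batch_paths`) in favor of two bulk calls, which presumably leaves hash-based deduplication to the server side of `upload_paths`. What the new code does rely on is ordering: `upload_paths` must return one info per input, in input order, so the returned ids stay index-aligned with `ann_paths`. A toy stub of that assumed contract:

    # Hypothetical stub illustrating the ordering contract upload_dataset
    # now depends on (not the real public_api client).
    class _FakeImageApi:
        def upload_paths(self, dataset_id, names, paths, progress_cb=None):
            infos = [{'id': 100 + i, 'name': n} for i, n in enumerate(names)]
            if progress_cb is not None:
                progress_cb(len(infos))
            return infos

    api = _FakeImageApi()
    infos = api.upload_paths(7, ['a.png', 'b.png'], ['/tmp/a.png', '/tmp/b.png'])
    assert [info['name'] for info in infos] == ['a.png', 'b.png']  # order preserved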