
Commit f8b0ff8

Merge pull request #209 from azavea/lf/rv1-parallel
Add parallel prediction for old version of object detection
2 parents d73efdc + c5d9471 commit f8b0ff8

5 files changed: +152 −46 lines changed


src/config-samples/detection/train/mobilenet.config

+1
@@ -141,6 +141,7 @@ model {
 
 train_config: {
   batch_size: 8
+  num_steps: 40000 # Set to 0 for an indefinite number of steps
   optimizer {
     rms_prop_optimizer: {
       learning_rate: {

src/rv/detection/commands/merge_predictions.py

+27-31
@@ -11,46 +11,42 @@
     MyTemporaryDirectory)
 
 
-def get_annotations_paths(projects_path, temp_dir):
-    annotations_paths = []
-    with open(projects_path, 'r') as projects_file:
-        projects = json.load(projects_file)
-        for project_ind, project in enumerate(projects):
-            annotations_uri = project['annotations']
-            annotations_path = download_if_needed(
-                annotations_uri, temp_dir)
-            annotations_paths.append(annotations_path)
-    return annotations_paths
-
-
-def merge_annotations(annotations_list):
-    all_annotations = copy.deepcopy(annotations_list[0])
-    for annotations in annotations_list[1:]:
-        all_annotations['features'].extend(annotations['features'])
-    return all_annotations
+def _merge_predictions(predictions_list):
+    merged_predictions = copy.deepcopy(predictions_list[0])
+    for predictions in predictions_list[1:]:
+        merged_predictions['features'].extend(predictions['features'])
+    return merged_predictions
 
 
 @click.command()
 @click.argument('projects_uri')
-@click.argument('output_uri')
+@click.argument('output_dir_uri')
 @click.option('--save-temp', is_flag=True)
-def merge_predictions(projects_uri, output_uri, save_temp):
+def merge_predictions(projects_uri, output_dir_uri, save_temp):
     prefix = temp_root_dir
     temp_dir = os.path.join(prefix, 'merge-predictions') if save_temp else None
     with MyTemporaryDirectory(temp_dir, prefix) as temp_dir:
         projects_path = download_if_needed(projects_uri, temp_dir)
-        output_path = get_local_path(output_uri, temp_dir)
-
-        annotation_paths = get_annotations_paths(projects_path, temp_dir)
-        annotations_list = []
-        for annotation_path in annotation_paths:
-            with open(annotation_path, 'r') as annotation_file:
-                annotations_list.append(json.load(annotation_file))
-
-        annotations = merge_annotations(annotations_list)
-        with open(output_path, 'w') as output_file:
-            json.dump(annotations, output_file, indent=4)
-        upload_if_needed(output_path, output_uri)
+
+        # For each project:
+        # download the predictions files, merge them, and upload the merged
+        # predictions.
+        projects = json.load(open(projects_path))
+        for project in projects:
+            predictions_list = []
+            for image_ind, image in enumerate(project['images']):
+                predictions_uri = os.path.join(
+                    output_dir_uri, project['id'],
+                    '{}.json'.format(image_ind))
+                predictions_path = download_if_needed(
+                    predictions_uri, temp_dir)
+                predictions_list.append(json.load(open(predictions_path)))
+
+            output_uri = project['annotations']
+            output_path = get_local_path(output_uri, temp_dir)
+            predictions = _merge_predictions(predictions_list)
+            json.dump(predictions, open(output_path, 'w'))
+            upload_if_needed(output_path, output_uri)
 
 
 if __name__ == '__main__':
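
Both merge_predictions and predict_array (below) are driven by a projects JSON file; from the fields referenced in the diff, each entry carries an 'id', a list of 'images', and an 'annotations' URI that receives the merged output. A minimal sketch of that file, with placeholder S3 URIs:

import json

# Hypothetical projects file; the field names ('id', 'images', 'annotations')
# come from the diff, the URIs are placeholders.
projects = [
    {
        'id': 'project-1',
        'images': ['s3://bucket/project-1/image-0.tif',
                   's3://bucket/project-1/image-1.tif'],
        'annotations': 's3://bucket/project-1/predictions.json'
    }
]

with open('projects.json', 'w') as projects_file:
    json.dump(projects, projects_file, indent=4)

With an output_dir_uri of s3://bucket/output, the per-image predictions for this project would be looked up at s3://bucket/output/project-1/0.json and s3://bucket/output/project-1/1.json before being merged and uploaded to the 'annotations' URI.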

src/rv/detection/commands/predict_array.py

+23-12
@@ -14,34 +14,46 @@
 @click.argument('inference_graph_uri')
 @click.argument('label_map_uri')
 @click.argument('projects_uri')
+@click.argument('output_dir_uri')
 @click.option('--mask-uri', default=None,
               help='URI for mask GeoJSON file to use as filter for detections')
 @click.option('--channel-order', nargs=3, type=int,
               default=default_channel_order, help='Index of RGB channels')
 @click.option('--chip-size', default=300)
 @click.option('--score-thresh', default=0.5,
               help='Score threshold of predictions to keep')
-@click.option('--merge-thresh', default=0.05,
+@click.option('--merge-thresh', default=0.5,
               help='IOU threshold for merging predictions')
 @click.option('--save-temp', is_flag=True)
 def predict_array(inference_graph_uri, label_map_uri, projects_uri,
-                  mask_uri, channel_order, chip_size, score_thresh,
-                  merge_thresh, save_temp):
-    job_index = int(os.environ['AWS_BATCH_JOB_ARRAY_INDEX'])
+                  output_dir_uri, mask_uri, channel_order, chip_size,
+                  score_thresh, merge_thresh, save_temp):
+    job_ind = int(os.environ['AWS_BATCH_JOB_ARRAY_INDEX'])
 
     prefix = temp_root_dir
     temp_dir = os.path.join(prefix, 'predict-array') if save_temp else None
     with MyTemporaryDirectory(temp_dir, prefix) as temp_dir:
         projects_path = download_if_needed(projects_uri, temp_dir)
         with open(projects_path, 'r') as projects_file:
             projects = json.load(projects_file)
-        if job_index >= len(projects):
-            raise ValueError(
-                'There are {} projects and job_index is {}!'.format(
-                    len(projects), job_index))
-        project = projects[job_index]
-        image_uris = project['images']
-        output_uri = project['annotations']
+
+        def get_image_coords():
+            image_ind = 0
+            for project_ind, project in enumerate(projects):
+                for project_image_ind, image_uri in enumerate(project['images']):
+                    if job_ind == image_ind:
+                        return project_ind, project_image_ind
+                    image_ind += 1
+
+        # Run predict for single image and generate ouput_uri based on
+        # the index of the image within the project.
+        project_ind, project_image_ind = get_image_coords()
+        project = projects[project_ind]
+        image_uri = project['images'][project_image_ind]
+        image_uris = [image_uri]
+        output_uri = os.path.join(
+            output_dir_uri, project['id'],
+            '{}.json'.format(project_image_ind))
         output_debug_uri = None
 
         _predict(inference_graph_uri, label_map_uri, image_uris,
@@ -50,6 +62,5 @@ def predict_array(inference_graph_uri, label_map_uri, projects_uri,
                  save_temp)
 
 
-
 if __name__ == '__main__':
     predict_array()
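
predict_array now handles one image per AWS Batch array element: Batch sets AWS_BATCH_JOB_ARRAY_INDEX on each child job, and get_image_coords walks the projects in order to turn that flat index into a (project, image) pair. A standalone sketch of the same mapping over a hypothetical two-project list:

# Hypothetical projects list; only the 'images' field matters for the mapping.
projects = [
    {'id': 'a', 'images': ['a0.tif', 'a1.tif']},
    {'id': 'b', 'images': ['b0.tif']},
]


def get_image_coords(projects, job_ind):
    # Walk images in project order until the flat index is reached.
    image_ind = 0
    for project_ind, project in enumerate(projects):
        for project_image_ind, _ in enumerate(project['images']):
            if job_ind == image_ind:
                return project_ind, project_image_ind
            image_ind += 1


# Flat indices 0, 1, 2 map to (0, 0), (0, 1), (1, 0).
print([get_image_coords(projects, i) for i in range(3)])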

src/rv/detection/commands/train.py

+36-3
@@ -1,9 +1,11 @@
-from os.path import join, dirname, splitext
+from os.path import join, basename
 import os
 from subprocess import Popen
 import zipfile
 from threading import Timer
 from urllib.parse import urlparse
+import glob
+import re
 
 import click
 
@@ -14,6 +16,30 @@
 from rv.detection.commands.settings import temp_root_dir
 
 
+def get_last_checkpoint_path(train_root_dir):
+    index_paths = glob.glob(join(train_root_dir, 'train', '*.index'))
+    checkpoint_ids = []
+    for index_path in index_paths:
+        match = re.match(r'model.ckpt-(\d+).index', basename(index_path))
+        checkpoint_ids.append(int(match.group(1)))
+    checkpoint_id = max(checkpoint_ids)
+    checkpoint_path = join(
+        train_root_dir, 'train', 'model.ckpt-{}'.format(checkpoint_id))
+    return checkpoint_path
+
+
+def export_inference_graph(train_root_dir, config_path, inference_graph_path):
+    checkpoint_path = get_last_checkpoint_path(train_root_dir)
+    print('Exporting checkpoint {}...'.format(checkpoint_path))
+    train_process = Popen([
+        'python', '/opt/src/tf/object_detection/export_inference_graph.py',
+        '--input_type', 'image_tensor',
+        '--pipeline_config_path', config_path,
+        '--checkpoint_path', checkpoint_path,
+        '--inference_graph_path', inference_graph_path])
+    train_process.wait()
+
+
 @click.command()
 @click.argument('config_uri')
 @click.argument('train_dataset_uri')
@@ -41,6 +67,7 @@ def train(config_uri, train_dataset_uri, val_dataset_uri, model_checkpoint_uri,
     make_dir(train_root_dir)
     train_dir = join(train_root_dir, 'train')
     eval_dir = join(train_root_dir, 'eval')
+    inference_graph_path = join(train_root_dir, 'inference_graph.pb')
 
     def process_zip_file(uri, temp_dir, link_dir):
         if uri.endswith('.zip'):
@@ -84,9 +111,15 @@ def sync_train_dir(delete=True):
         'tensorboard', '--logdir={}'.format(train_root_dir)],
         preexec_fn=on_parent_exit('SIGTERM'))
 
+    # After training finishes due to num_steps exceeded,
+    # kill monitor processes, export inference graph, and upload.
     train_process.wait()
-    eval_process.wait()
-    tensorboard_process.wait()
+    eval_process.kill()
+    tensorboard_process.kill()
+    export_inference_graph(
+        train_root_dir, config_path, inference_graph_path)
+    if urlparse(train_uri).scheme == 's3':
+        sync_dir(train_root_dir, train_uri, delete=True)
 
 
 if __name__ == '__main__':
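
get_last_checkpoint_path picks the newest checkpoint by parsing the step number out of each model.ckpt-<step>.index file in the train directory and taking the maximum. A small illustration of that selection over made-up filenames:

import re
from os.path import basename

# Made-up checkpoint index files, as TensorFlow writes them into train/.
index_paths = ['train/model.ckpt-10000.index', 'train/model.ckpt-40000.index']

checkpoint_ids = []
for index_path in index_paths:
    match = re.match(r'model.ckpt-(\d+).index', basename(index_path))
    checkpoint_ids.append(int(match.group(1)))

# The latest step wins, so this prints model.ckpt-40000.
print('model.ckpt-{}'.format(max(checkpoint_ids)))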
src/rv/detection/commands/parallel_predict.py

+65
@@ -0,0 +1,65 @@
+import json
+import os
+import click
+
+from rv.utils.files import (
+    download_if_needed, MyTemporaryDirectory)
+from rv.utils.batch import _batch_submit
+from rv.detection.commands.settings import temp_root_dir
+
+
+def make_predict_array_cmd(inference_graph_uri, label_map_uri, projects_uri,
+                           output_dir_uri):
+    return 'python -m rv.detection.run predict_array {} {} {} {}'.format(
+        inference_graph_uri, label_map_uri, projects_uri, output_dir_uri)
+
+
+def make_merge_predictions_cmd(projects_uri, output_dir_uri):
+    return 'python -m rv.detection.run merge_predictions {} {}'.format(
+        projects_uri, output_dir_uri)
+
+
+def get_nb_images(projects):
+    nb_images = 0
+    for project in projects:
+        nb_images += len(project['images'])
+    return nb_images
+
+
+@click.command()
+@click.argument('projects_uri')
+@click.argument('label_map_uri')
+@click.argument('inference_graph_uri')
+@click.argument('output_dir_uri')
+@click.option('--branch-name', default='develop')
+@click.option('--attempts', default=1)
+@click.option('--cpu', is_flag=True)
+def parallel_predict(projects_uri, label_map_uri, inference_graph_uri,
+                     output_dir_uri,
+                     branch_name, attempts, cpu):
+    prefix = temp_root_dir
+    temp_dir = os.path.join(prefix, 'parallel-predict')
+    with MyTemporaryDirectory(temp_dir, prefix) as temp_dir:
+        # Load projects and count number of images
+        projects_path = download_if_needed(projects_uri, temp_dir)
+        projects = json.load(open(projects_path))
+        nb_images = get_nb_images(projects)
+
+        # Submit an array job with nb_images elements.
+        command = make_predict_array_cmd(
+            inference_graph_uri, label_map_uri, projects_uri, output_dir_uri)
+        '''
+        predict_job_id = _batch_submit(
+            branch_name, command, attempts=attempts, cpu=cpu,
+            array_size=nb_images)
+        '''
+        # Submit a dependent merge_predictions job.
+        command = make_merge_predictions_cmd(
+            projects_uri, output_dir_uri)
+        _batch_submit(
+            branch_name, command, attempts=attempts, cpu=cpu)
+            #parent_job_ids=[predict_job_id])
+
+
+if __name__ == '__main__':
+    parallel_predict()
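
parallel_predict ties the pieces together: it sizes an AWS Batch array job to the total image count, builds the predict_array command for the array job, and builds a merge_predictions command intended to run afterwards (the array submission and the parent_job_ids dependency are commented out in this commit). A minimal sketch of that wiring with placeholder URIs and no Batch calls:

# A quick local check of the command wiring, using a hypothetical projects
# list and placeholder URIs; nothing here talks to AWS Batch.
projects = [{'id': 'p1', 'images': ['s3://bucket/p1/0.tif', 's3://bucket/p1/1.tif']}]

nb_images = sum(len(project['images']) for project in projects)

predict_cmd = 'python -m rv.detection.run predict_array {} {} {} {}'.format(
    's3://bucket/graph.pb', 's3://bucket/label_map.pbtxt',
    's3://bucket/projects.json', 's3://bucket/output')
merge_cmd = 'python -m rv.detection.run merge_predictions {} {}'.format(
    's3://bucket/projects.json', 's3://bucket/output')

# predict_array would run as an array job with nb_images (here 2) elements;
# merge_predictions is meant to run after all of them complete.
print(nb_images, predict_cmd, merge_cmd, sep='\n')

Assuming parallel_predict is registered in rv.detection.run the same way as predict_array and merge_predictions, the invocation would look like: python -m rv.detection.run parallel_predict s3://bucket/projects.json s3://bucket/label_map.pbtxt s3://bucket/graph.pb s3://bucket/output (the bucket paths are placeholders).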
