Skip to content

Commit 07fed86

Browse files
committed
Add code for inference
1 parent 5dd7c6b commit 07fed86

4 files changed

+103
-0
lines changed

images/2008_002778_ar0.5.jpg

255 KB
Loading

images/2008_002778_ar1.0.jpg

202 KB
Loading

images/2008_002778_ar2.0.jpg

243 KB
Loading

inference.py

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Thu Jul 26 19:23:17 2018
4+
5+
@author: sakurai
6+
"""
7+
import os
8+
9+
import matplotlib.pyplot as plt
10+
import numpy as np
11+
12+
import chainer
13+
from chainer.dataset.download import get_dataset_directory, cached_download
14+
15+
from evaluation import stretch
16+
from evaluation import TransformTestdata
17+
import main_resnet
18+
19+
20+
class Restorer(object):
21+
22+
_model_url = (
23+
'https://github.com/ronekko/model_parameters/releases/download/v1/'
24+
'resnet101-20171206T000238-4c22664.npz')
25+
26+
def __init__(self):
27+
self.net = main_resnet.Resnet(
28+
32, [3, 4, 5, 6], [64, 128, 256, 512], False)
29+
30+
npz_path = download_model(
31+
self._model_url, 'aspect_ratio_restorer')
32+
chainer.serializers.load_npz(npz_path, self.net)
33+
self.preprocessor = Preprocessor()
34+
35+
def predict_aspect_ratio(self, image):
36+
preprocessed_array = self.preprocessor.apply(image)
37+
with chainer.no_backprop_mode():
38+
with chainer.using_config('train', False):
39+
log_ar = self.net(preprocessed_array).array
40+
ar = self.net.xp.exp(log_ar)
41+
return ar[0, 0]
42+
43+
def restore_image(self, image, return_aspect_ratio=True):
44+
ar = self.predict_aspect_ratio(image)
45+
image = image.transpose(2, 0, 1)
46+
image = stretch(image, -np.log(ar)) # inverse stratching
47+
image = image.transpose(1, 2, 0)
48+
49+
if return_aspect_ratio:
50+
return image, ar
51+
else:
52+
return image
53+
54+
55+
class Preprocessor(object):
56+
def __init__(self):
57+
self.transform = TransformTestdata(
58+
scaled_size=256, crop_size=224, log_ars=[0.0], preprocess=None)
59+
60+
def apply(self, image):
61+
"""Applies the preprocess for an image.
62+
Note that an input must be a CHW shaped, RGB ordered,
63+
[0, 255] valued image.
64+
"""
65+
return np.array(self.transform(image.transpose(2, 0, 1)))
66+
67+
68+
def download_model(url, subdir_name=None, root_dir_name='ronekko'):
69+
root_dir_path = get_dataset_directory(root_dir_name)
70+
basename = os.path.basename(url)
71+
if subdir_name is None:
72+
subdir_name = ''
73+
save_dir_path = os.path.join(root_dir_path, subdir_name)
74+
save_file_path = os.path.join(save_dir_path, basename)
75+
76+
if not os.path.exists(save_file_path):
77+
cache_path = cached_download(url)
78+
if not os.path.exists(save_dir_path):
79+
os.mkdir(save_dir_path)
80+
os.rename(cache_path, save_file_path)
81+
return save_file_path
82+
83+
84+
if __name__ == '__main__':
85+
# image_filename = 'images/2008_002778_ar1.0.jpg' # original image
86+
# image_filename = 'images/2008_002778_ar2.0.jpg' # 2 times wider
87+
image_filename = 'images/2008_002778_ar0.5.jpg' # 2 times taller
88+
89+
restorer = Restorer()
90+
91+
image = plt.imread(image_filename)
92+
93+
# If you want to only estimate the aspect ratio of the input image,
94+
# use `Restorer.predict_aspect_ratio` method.
95+
aspect_ratio = restorer.predict_aspect_ratio(image)
96+
print(aspect_ratio)
97+
98+
# If you want to restore the input image,
99+
# use `Restorer.restore_image` method.
100+
restored_image, aspect_ratio = restorer.restore_image(image)
101+
plt.imshow(restored_image)
102+
plt.show()
103+
print('restored_image.shape =', restored_image.shape)

0 commit comments

Comments
 (0)