|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +Created on Thu Jul 26 19:23:17 2018 |
| 4 | +
|
| 5 | +@author: sakurai |
| 6 | +""" |
| 7 | +import os |
| 8 | + |
| 9 | +import matplotlib.pyplot as plt |
| 10 | +import numpy as np |
| 11 | + |
| 12 | +import chainer |
| 13 | +from chainer.dataset.download import get_dataset_directory, cached_download |
| 14 | + |
| 15 | +from evaluation import stretch |
| 16 | +from evaluation import TransformTestdata |
| 17 | +import main_resnet |
| 18 | + |
| 19 | + |
| 20 | +class Restorer(object): |
| 21 | + |
| 22 | + _model_url = ( |
| 23 | + 'https://github.com/ronekko/model_parameters/releases/download/v1/' |
| 24 | + 'resnet101-20171206T000238-4c22664.npz') |
| 25 | + |
| 26 | + def __init__(self): |
| 27 | + self.net = main_resnet.Resnet( |
| 28 | + 32, [3, 4, 5, 6], [64, 128, 256, 512], False) |
| 29 | + |
| 30 | + npz_path = download_model( |
| 31 | + self._model_url, 'aspect_ratio_restorer') |
| 32 | + chainer.serializers.load_npz(npz_path, self.net) |
| 33 | + self.preprocessor = Preprocessor() |
| 34 | + |
| 35 | + def predict_aspect_ratio(self, image): |
| 36 | + preprocessed_array = self.preprocessor.apply(image) |
| 37 | + with chainer.no_backprop_mode(): |
| 38 | + with chainer.using_config('train', False): |
| 39 | + log_ar = self.net(preprocessed_array).array |
| 40 | + ar = self.net.xp.exp(log_ar) |
| 41 | + return ar[0, 0] |
| 42 | + |
| 43 | + def restore_image(self, image, return_aspect_ratio=True): |
| 44 | + ar = self.predict_aspect_ratio(image) |
| 45 | + image = image.transpose(2, 0, 1) |
| 46 | + image = stretch(image, -np.log(ar)) # inverse stratching |
| 47 | + image = image.transpose(1, 2, 0) |
| 48 | + |
| 49 | + if return_aspect_ratio: |
| 50 | + return image, ar |
| 51 | + else: |
| 52 | + return image |
| 53 | + |
| 54 | + |
| 55 | +class Preprocessor(object): |
| 56 | + def __init__(self): |
| 57 | + self.transform = TransformTestdata( |
| 58 | + scaled_size=256, crop_size=224, log_ars=[0.0], preprocess=None) |
| 59 | + |
| 60 | + def apply(self, image): |
| 61 | + """Applies the preprocess for an image. |
| 62 | + Note that an input must be a CHW shaped, RGB ordered, |
| 63 | + [0, 255] valued image. |
| 64 | + """ |
| 65 | + return np.array(self.transform(image.transpose(2, 0, 1))) |
| 66 | + |
| 67 | + |
| 68 | +def download_model(url, subdir_name=None, root_dir_name='ronekko'): |
| 69 | + root_dir_path = get_dataset_directory(root_dir_name) |
| 70 | + basename = os.path.basename(url) |
| 71 | + if subdir_name is None: |
| 72 | + subdir_name = '' |
| 73 | + save_dir_path = os.path.join(root_dir_path, subdir_name) |
| 74 | + save_file_path = os.path.join(save_dir_path, basename) |
| 75 | + |
| 76 | + if not os.path.exists(save_file_path): |
| 77 | + cache_path = cached_download(url) |
| 78 | + if not os.path.exists(save_dir_path): |
| 79 | + os.mkdir(save_dir_path) |
| 80 | + os.rename(cache_path, save_file_path) |
| 81 | + return save_file_path |
| 82 | + |
| 83 | + |
| 84 | +if __name__ == '__main__': |
| 85 | +# image_filename = 'images/2008_002778_ar1.0.jpg' # original image |
| 86 | +# image_filename = 'images/2008_002778_ar2.0.jpg' # 2 times wider |
| 87 | + image_filename = 'images/2008_002778_ar0.5.jpg' # 2 times taller |
| 88 | + |
| 89 | + restorer = Restorer() |
| 90 | + |
| 91 | + image = plt.imread(image_filename) |
| 92 | + |
| 93 | + # If you want to only estimate the aspect ratio of the input image, |
| 94 | + # use `Restorer.predict_aspect_ratio` method. |
| 95 | + aspect_ratio = restorer.predict_aspect_ratio(image) |
| 96 | + print(aspect_ratio) |
| 97 | + |
| 98 | + # If you want to restore the input image, |
| 99 | + # use `Restorer.restore_image` method. |
| 100 | + restored_image, aspect_ratio = restorer.restore_image(image) |
| 101 | + plt.imshow(restored_image) |
| 102 | + plt.show() |
| 103 | + print('restored_image.shape =', restored_image.shape) |
0 commit comments