# test.py — from a fork of MeditatorE/Cartoon-Converter-Platform
import torch
import os
import numpy as np
import argparse
from PIL import Image
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as vutils
from network.Transformer import Transformer
# Command-line options for batch style conversion of the images in --input_dir.
parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', default = 'Upload',
                    help='directory scanned for .jpg/.png input images')
# type=int is required: without it a value given on the command line arrives as a
# str, and the aspect-ratio arithmetic on load_size later raises TypeError.
parser.add_argument('--load_size', type=int, default = 450,
                    help='length in pixels of the longer image side after resizing')
parser.add_argument('--model_path', default = './pretrained_model',
                    help='directory containing the <style>_net_G_float.pth weight files')
parser.add_argument('--style', default = 'Hayao',
                    help='style name; selects the weight file and suffixes output names')
parser.add_argument('--output_dir', default = 'store',
                    help='directory the converted .jpg images are written to')
parser.add_argument('--gpu', type=int, default = 0,
                    help='CUDA device index to run on; -1 runs on the CPU')
opt = parser.parse_args()
# Input file extensions accepted by the conversion loop.
valid_ext = ['.jpg', '.png']
# makedirs also creates missing parent directories and is a no-op when the
# output directory already exists.
os.makedirs(opt.output_dir, exist_ok=True)
# load pretrained model
model = Transformer()
# Deserialize onto the CPU first so weights saved from a GPU can still be opened
# on a CPU-only machine; the model is moved to the GPU below when requested.
state_dict = torch.load(os.path.join(opt.model_path, opt.style + '_net_G_float.pth'),
                        map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
# The default --gpu is 0, which would crash on machines without CUDA; fall back to
# the CPU path (and record it in opt so later code takes the CPU branch too).
if opt.gpu > -1 and not torch.cuda.is_available():
    print('CUDA is not available, falling back to CPU mode')
    opt.gpu = -1
if opt.gpu > -1:
    print('GPU mode')
    # Honor the requested device index: a bare .cuda() (here and on the input
    # tensors later) targets the *current* device, which defaults to 0.
    torch.cuda.set_device(opt.gpu)
    model.cuda()
else:
    print('CPU mode')
    model.float()
# Convert every supported image in opt.input_dir and save the stylized result to
# opt.output_dir as "<stem>_<style>.jpg".
for files in os.listdir(opt.input_dir):
    stem, ext = os.path.splitext(files)
    # Compare case-insensitively so e.g. "photo.JPG" is not silently skipped.
    if ext.lower() not in valid_ext:
        continue
    # load image
    input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
    # resize image, keep aspect ratio: the longer side becomes load_size.
    # int() guards against load_size arriving as a str from the command line.
    load_size = int(opt.load_size)
    width, height = input_image.size  # PIL reports (width, height)
    ratio = width * 1.0 / height
    if ratio > 1:
        width = load_size
        height = int(width * 1.0 / ratio)
    else:
        height = load_size
        width = int(height * ratio)
    # Extremely elongated images could otherwise round a side down to 0 pixels.
    width, height = max(1, width), max(1, height)
    input_image = input_image.resize((width, height), Image.BICUBIC)
    pixels = np.asarray(input_image)
    # RGB -> BGR (the output is swapped back below; presumably the pretrained
    # weights expect BGR input -- TODO confirm against the training code)
    pixels = pixels[:, :, [2, 1, 0]]
    # ToTensor turns the uint8 HxWxC array into a float CxHxW tensor in [0, 1].
    input_tensor = transforms.ToTensor()(pixels).unsqueeze(0)
    # preprocess, (0, 1) -> (-1, 1)
    input_tensor = -1 + 2 * input_tensor
    if opt.gpu > -1:
        input_tensor = input_tensor.cuda()
    else:
        input_tensor = input_tensor.float()
    # forward -- torch.no_grad() replaces Variable(volatile=True), which has been a
    # no-op since PyTorch 0.4 and let autograd record a graph during inference.
    with torch.no_grad():
        output_image = model(input_tensor)[0]
    # BGR -> RGB
    output_image = output_image[[2, 1, 0], :, :]
    # deprocess, (-1, 1) -> (0, 1)
    output_image = output_image.cpu().float() * 0.5 + 0.5
    # save; splitext keeps the full stem regardless of the extension's length.
    vutils.save_image(output_image, os.path.join(opt.output_dir, stem + '_' + opt.style + '.jpg'))
print('Done!')