Skip to content

Commit 22c0efb

Browse files
suquarkmingyuliutw
authored andcommitted
Bug fixing & add cuda options (NVIDIA#28)
* Fix an incorrect number. * Add CUDA options
1 parent f895e62 commit 22c0efb

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

demo.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
parser.add_argument('--style_image_path', default='./images/style1.png')
2020
parser.add_argument('--style_seg_path', default=[])
2121
parser.add_argument('--output_image_path', default='./results/example1.png')
22+
parser.add_argument('--cuda', type=bool, default=True, help='Enable CUDA.')
2223
args = parser.parse_args()
2324

2425
# Load model
@@ -28,7 +29,7 @@
2829
except:
2930
print("Fail to load PhotoWCT models. PhotoWCT submodule not updated?")
3031
exit()
31-
32+
3233
p_wct.cuda(0)
3334

3435
process_stylization.stylization(
@@ -38,4 +39,5 @@
3839
content_seg_path=args.content_seg_path,
3940
style_seg_path=args.style_seg_path,
4041
output_image_path=args.output_image_path,
42+
cuda=args.cuda,
4143
)

models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def forward_multiple(self, x):
163163
out = self.conv3_1(out)
164164
out = self.relu3_1(out)
165165

166-
if self.level < 3: return out, out2, out1
166+
if self.level < 4: return out, out2, out1
167167

168168
out3 = out
169169

process_stylization.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def __exit__(self, exc_type, exc_value, exc_tb):
3232
print(self.msg % (time.time() - self.start_time))
3333

3434

35-
def stylization(p_wct, content_image_path, style_image_path, content_seg_path, style_seg_path, output_image_path):
35+
def stylization(p_wct, content_image_path, style_image_path, content_seg_path, style_seg_path, output_image_path,
36+
cuda):
3637
# Load image
3738
cont_img = Image.open(content_image_path).convert('RGB')
3839
styl_img = Image.open(style_image_path).convert('RGB')
@@ -45,8 +46,14 @@ def stylization(p_wct, content_image_path, style_image_path, content_seg_path, s
4546

4647
cont_img = transforms.ToTensor()(cont_img).unsqueeze(0)
4748
styl_img = transforms.ToTensor()(styl_img).unsqueeze(0)
48-
cont_img = Variable(cont_img.cuda(0), volatile=True)
49-
styl_img = Variable(styl_img.cuda(0), volatile=True)
49+
50+
if cuda:
51+
cont_img = cont_img.cuda(0)
52+
styl_img = styl_img.cuda(0)
53+
p_wct.cuda(0)
54+
55+
cont_img = Variable(cont_img, volatile=True)
56+
styl_img = Variable(styl_img, volatile=True)
5057

5158
cont_seg = np.asarray(cont_seg)
5259
styl_seg = np.asarray(styl_seg)
@@ -59,6 +66,10 @@ def stylization(p_wct, content_image_path, style_image_path, content_seg_path, s
5966
out_img = p_pro.process(output_image_path, content_image_path)
6067
out_img.save(output_image_path)
6168

69+
if not cuda:
70+
print("NotImplemented: The CPU version of smooth filter has not been implemented currently.")
71+
return
72+
6273
with Timer("Elapsed time in post processing: %f"):
6374
out_img = smooth_filter(output_image_path, content_image_path, f_radius=15, f_edge=1e-1)
6475
out_img.save(output_image_path)

process_stylization_examples.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
parser = argparse.ArgumentParser(description='Photorealistic Image Stylization')
1616
parser.add_argument('--model', default='./PhotoWCTModels/photo_wct.pth',
1717
help='Path to the PhotoWCT model. These are provided by the PhotoWCT submodule, please use `git submodule update --init --recursive` to pull.')
18+
parser.add_argument('--cuda', type=bool, default=True, help='Enable CUDA.')
1819
args = parser.parse_args()
1920

2021
folder = 'examples'
@@ -29,7 +30,6 @@
2930
# Load model
3031
p_wct = PhotoWCT()
3132
p_wct.load_state_dict(torch.load(args.model))
32-
p_wct.cuda(0)
3333

3434
for f in cont_img_list:
3535
print("Process " + f)
@@ -47,4 +47,5 @@
4747
content_seg_path=content_seg_path,
4848
style_seg_path=style_seg_path,
4949
output_image_path=output_image_path,
50+
cuda=args.cuda,
5051
)

0 commit comments

Comments
 (0)