@@ -32,7 +32,8 @@ def __exit__(self, exc_type, exc_value, exc_tb):
32
32
print (self .msg % (time .time () - self .start_time ))
33
33
34
34
35
- def stylization (p_wct , content_image_path , style_image_path , content_seg_path , style_seg_path , output_image_path ):
35
+ def stylization (p_wct , content_image_path , style_image_path , content_seg_path , style_seg_path , output_image_path ,
36
+ cuda ):
36
37
# Load image
37
38
cont_img = Image .open (content_image_path ).convert ('RGB' )
38
39
styl_img = Image .open (style_image_path ).convert ('RGB' )
@@ -45,8 +46,14 @@ def stylization(p_wct, content_image_path, style_image_path, content_seg_path, s
45
46
46
47
cont_img = transforms .ToTensor ()(cont_img ).unsqueeze (0 )
47
48
styl_img = transforms .ToTensor ()(styl_img ).unsqueeze (0 )
48
- cont_img = Variable (cont_img .cuda (0 ), volatile = True )
49
- styl_img = Variable (styl_img .cuda (0 ), volatile = True )
49
+
50
+ if cuda :
51
+ cont_img = cont_img .cuda (0 )
52
+ styl_img = styl_img .cuda (0 )
53
+ p_wct .cuda (0 )
54
+
55
+ cont_img = Variable (cont_img , volatile = True )
56
+ styl_img = Variable (styl_img , volatile = True )
50
57
51
58
cont_seg = np .asarray (cont_seg )
52
59
styl_seg = np .asarray (styl_seg )
@@ -59,6 +66,10 @@ def stylization(p_wct, content_image_path, style_image_path, content_seg_path, s
59
66
out_img = p_pro .process (output_image_path , content_image_path )
60
67
out_img .save (output_image_path )
61
68
69
+ if not cuda :
70
+ print ("NotImplemented: The CPU version of smooth filter has not been implemented currently." )
71
+ return
72
+
62
73
with Timer ("Elapsed time in post processing: %f" ):
63
74
out_img = smooth_filter (output_image_path , content_image_path , f_radius = 15 , f_edge = 1e-1 )
64
75
out_img .save (output_image_path )
0 commit comments