9
9
import numpy as np
10
10
import cv2
11
11
12
+ try :
13
+ torch .cuda .current_device ()
14
+ TORCH_MODE = 'GPU'
15
+ get_dev_vec = lambda x : x .cuda (0 )
16
+ except AssertionError :
17
+ TORCH_MODE = 'CPU'
18
+ get_dev_vec = lambda x : x .cpu ()
19
+
12
20
class PhotoWCT (nn .Module ):
13
21
def __init__ (self , args ):
14
22
super (PhotoWCT , self ).__init__ ()
@@ -131,8 +139,8 @@ def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
131
139
styl_mask = np .where (t_styl_seg .reshape (t_styl_seg .shape [0 ] * t_styl_seg .shape [1 ]) == l )
132
140
if cont_mask [0 ].size <= 0 or styl_mask [0 ].size <= 0 :
133
141
continue
134
- cont_indi = torch .LongTensor (cont_mask [0 ]). cuda ( 0 )
135
- styl_indi = torch .LongTensor (styl_mask [0 ]). cuda ( 0 )
142
+ cont_indi = get_dev_vec ( torch .LongTensor (cont_mask [0 ]))
143
+ styl_indi = get_dev_vec ( torch .LongTensor (styl_mask [0 ]))
136
144
cFFG = torch .index_select (cont_feat_view , 1 , cont_indi )
137
145
sFFG = torch .index_select (styl_feat_view , 1 , styl_indi )
138
146
tmp_target_feature = self .__wct_core (cFFG , sFFG )
@@ -148,7 +156,7 @@ def __wct_core(self, cont_feat, styl_feat):
148
156
c_mean = c_mean .unsqueeze (1 ).expand_as (cont_feat )
149
157
cont_feat = cont_feat - c_mean
150
158
151
- iden = torch .eye (cFSize [0 ]). cuda ( )#.double()
159
+ iden = get_dev_vec ( torch .eye (cFSize [0 ]))#.double()
152
160
contentConv = torch .mm (cont_feat , cont_feat .t ()).div (cFSize [1 ] - 1 ) + iden
153
161
# del iden
154
162
c_u , c_e , c_v = torch .svd (contentConv , some = False )
0 commit comments