Skip to content

Commit 4cb924a

Browse files
committed
updating demo notebook, making postBuild executable
1 parent 6fd2003 commit 4cb924a

File tree

3 files changed

+209
-51
lines changed

3 files changed

+209
-51
lines changed

notebooks/demo.ipynb

+196-46
Large diffs are not rendered by default.

photo_wct.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
import numpy as np
1010
import cv2
1111

12+
try:
13+
torch.cuda.current_device()
14+
TORCH_MODE = 'GPU'
15+
get_dev_vec = lambda x: x.cuda(0)
16+
except AssertionError:
17+
TORCH_MODE = 'CPU'
18+
get_dev_vec = lambda x: x.cpu()
19+
1220
class PhotoWCT(nn.Module):
1321
def __init__(self, args):
1422
super(PhotoWCT, self).__init__()
@@ -131,8 +139,8 @@ def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
131139
styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)
132140
if cont_mask[0].size <= 0 or styl_mask[0].size <= 0 :
133141
continue
134-
cont_indi = torch.LongTensor(cont_mask[0]).cuda(0)
135-
styl_indi = torch.LongTensor(styl_mask[0]).cuda(0)
142+
cont_indi = get_dev_vec(torch.LongTensor(cont_mask[0]))
143+
styl_indi = get_dev_vec(torch.LongTensor(styl_mask[0]))
136144
cFFG = torch.index_select(cont_feat_view, 1, cont_indi)
137145
sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
138146
tmp_target_feature = self.__wct_core(cFFG, sFFG)
@@ -148,7 +156,7 @@ def __wct_core(self, cont_feat, styl_feat):
148156
c_mean = c_mean.unsqueeze(1).expand_as(cont_feat)
149157
cont_feat = cont_feat - c_mean
150158

151-
iden = torch.eye(cFSize[0]).cuda()#.double()
159+
iden = get_dev_vec(torch.eye(cFSize[0]))#.double()
152160
contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden
153161
# del iden
154162
c_u, c_e, c_v = torch.svd(contentConv, some=False)

postBuild_temp postBuild

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
curl -L "https://www.dropbox.com/s/dlmpr7wabehq3x0/models.zip?dl=1" > models.zip
2-
unzip models.zip
1+
curl -L "https://www.dropbox.com/s/dlmpr7wabehq3x0/models.zip?dl=1" > models.zip
2+
unzip models.zip

0 commit comments

Comments
 (0)