import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import time

import numpy as np
import tensorflow as tf

import load_vgg_sol
import utils

def setup():
    utils.safe_mkdir('checkpoints')
    utils.safe_mkdir('checkpoints/style_transfer')  # saver.save() needs this directory to exist
    utils.safe_mkdir('outputs')

class StyleTransfer(object):
    def __init__(self, content_img, style_img, img_width, img_height):
        '''
        img_width and img_height are the dimensions we expect from the generated image.
        We will resize the input content image and input style image to match these dimensions.
        Feel free to alter any hyperparameter here and see how it affects your training.
        '''
        self.img_width = img_width
        self.img_height = img_height
        self.content_img = utils.get_resized_image(content_img, img_width, img_height)
        self.style_img = utils.get_resized_image(style_img, img_width, img_height)
        self.initial_img = utils.generate_noise_image(self.content_img, img_width, img_height)

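        # The hyperparameters below follow the original neural style transfer setup:
        # one content layer, five style layers with increasing weights, and a
        # content/style trade-off controlled by content_w and style_w.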
        ###############################
        ## TO DO
        ## create global step (gstep) and hyperparameters for the model
        self.content_layer = 'conv4_2'
        self.style_layers = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
        self.content_w = 0.01
        self.style_w = 1
        self.style_layer_w = [0.5, 1.0, 1.5, 3.0, 4.0]
        self.gstep = tf.Variable(0, dtype=tf.int32,
                                 trainable=False, name='global_step')
        self.lr = 2.0
        ###############################

    def create_input(self):
        '''
        We will use one input_img as a placeholder for the content image,
        style image, and generated image, because:
            1. they have the same dimensions
            2. we have to extract the same set of features from them
        We use a variable instead of a placeholder because we are, at the same time,
        training the generated image to get the desired result.

        Note: image height corresponds to number of rows, not columns.
        '''
        with tf.variable_scope('input') as scope:
            self.input_img = tf.get_variable('in_img',
                                             shape=([1, self.img_height, self.img_width, 3]),
                                             dtype=tf.float32,
                                             initializer=tf.zeros_initializer())

    def load_vgg(self):
        '''
        Load the saved model parameters of VGG-19, using the input_img
        as the input to compute the output at each layer of vgg.

        During training, VGG-19 mean-centered all images and found the mean pixels
        to be [123.68, 116.779, 103.939] along RGB dimensions. We have to subtract
        this mean from our images.
        '''
        self.vgg = load_vgg_sol.VGG(self.input_img)
        self.vgg.load()
        self.content_img -= self.vgg.mean_pixels
        self.style_img -= self.vgg.mean_pixels

    def _content_loss(self, P, F):
        ''' Calculate the loss between the feature representation of the
        content image and the generated image.

        Inputs:
            P: content representation of the content image
            F: content representation of the generated image
        Read the assignment handout for more details.

        Note: Don't use the coefficient 0.5 as defined in the paper.
        Use the coefficient defined in the assignment handout.
        '''
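        # Per the assignment handout, the squared difference is normalized by
        # 4 * s, where s is the total number of elements in P, instead of the
        # paper's plain 1/2 coefficient.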
        ###############################
        ## TO DO
        self.content_loss = tf.reduce_sum((F - P) ** 2) / (4.0 * P.size)
        ###############################

    def _gram_matrix(self, F, N, M):
        """ Create and return the Gram matrix for tensor F
        Hint: you'll first have to reshape F
        """
        ###############################
        ## TO DO
        F = tf.reshape(F, (M, N))
        return tf.matmul(tf.transpose(F), F)
        ###############################

    def _single_style_loss(self, a, g):
        """ Calculate the style loss at a certain layer
        Inputs:
            a is the feature representation of the style image at that layer
            g is the feature representation of the generated image at that layer
        Output:
            the style loss at a certain layer (which is E_l in the paper)

        Hint: 1. you'll have to use the function _gram_matrix()
              2. we'll use the same coefficient for style loss as in the paper
              3. a and g are feature representations, not Gram matrices
        """
        ###############################
        ## TO DO
        N = a.shape[3]  # number of filters
        M = a.shape[1] * a.shape[2]  # height times width of the feature map
        A = self._gram_matrix(a, N, M)
        G = self._gram_matrix(g, N, M)
        return tf.reduce_sum((G - A) ** 2 / ((2 * N * M) ** 2))
        ###############################

    def _style_loss(self, A):
        """ Calculate the total style loss as a weighted sum
        of style losses at all style layers
        Hint: you'll have to use _single_style_loss()
        """
        n_layers = len(A)
        E = [self._single_style_loss(A[i], getattr(self.vgg, self.style_layers[i])) for i in range(n_layers)]

        ###############################
        ## TO DO
        self.style_loss = sum([self.style_layer_w[i] * E[i] for i in range(n_layers)])
        ###############################

    def losses(self):
        with tf.variable_scope('losses') as scope:
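            # The content representation P and the style representations A are
            # computed once here, in temporary sessions, by assigning the content
            # or style image to input_img; they come back as numpy constants,
            # while the generated image's features stay symbolic tensors.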
            with tf.Session() as sess:
                # assign content image to the input variable
                sess.run(self.input_img.assign(self.content_img))
                gen_img_content = getattr(self.vgg, self.content_layer)
                content_img_content = sess.run(gen_img_content)
                self._content_loss(content_img_content, gen_img_content)

            with tf.Session() as sess:
                sess.run(self.input_img.assign(self.style_img))
                style_layers = sess.run([getattr(self.vgg, layer) for layer in self.style_layers])
                self._style_loss(style_layers)

            ##########################################
            ## TO DO: create total loss.
            ## Hint: don't forget the weights for the content loss and style loss
            self.total_loss = self.content_w * self.content_loss + self.style_w * self.style_loss
            ##########################################

    def optimize(self):
        ###############################
        ## TO DO: create optimizer
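        # minimize() updates every trainable variable in the graph; here that
        # should only be input_img, assuming load_vgg_sol builds the VGG weights
        # as non-trainable constants.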
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.total_loss,
                                                             global_step=self.gstep)
        ###############################

    def create_summary(self):
        ###############################
        ## TO DO: create summaries for all the losses
        ## Hint: don't forget to merge them
        with tf.name_scope('summaries'):
            tf.summary.scalar('content_loss', self.content_loss)
            tf.summary.scalar('style_loss', self.style_loss)
            tf.summary.scalar('total_loss', self.total_loss)
            self.summary_op = tf.summary.merge_all()
        ###############################

    def build(self):
        self.create_input()
        self.load_vgg()
        self.losses()
        self.optimize()
        self.create_summary()

    def train(self, n_iters):
        skip_step = 1
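        # skip_step controls how often the generated image and losses are
        # reported; it grows (1 -> 10 -> 20) as training progresses.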
        with tf.Session() as sess:

            ###############################
            ## TO DO:
            ## 1. initialize your variables
            ## 2. create writer to write your graph
            sess.run(tf.global_variables_initializer())
            writer = tf.summary.FileWriter('graphs/style_transfer', sess.graph)
            ###############################
            sess.run(self.input_img.assign(self.initial_img))

            ###############################
            ## TO DO:
            ## 1. create a saver object
            ## 2. check if a checkpoint exists, restore the variables
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/style_transfer/checkpoint'))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            ##############################

            initial_step = self.gstep.eval()

            start_time = time.time()
            for index in range(initial_step, n_iters):
                if index >= 5 and index < 20:
                    skip_step = 10
                elif index >= 20:
                    skip_step = 20

                sess.run(self.opt)
                if (index + 1) % skip_step == 0:
                    ###############################
                    ## TO DO: obtain generated image, loss, and summary
                    gen_image, total_loss, summary = sess.run([self.input_img,
                                                               self.total_loss,
                                                               self.summary_op])
                    ###############################

                    # add back the mean pixels we subtracted before
                    gen_image = gen_image + self.vgg.mean_pixels
                    writer.add_summary(summary, global_step=index)
                    print('Step {}\n   Sum: {:5.1f}'.format(index + 1, np.sum(gen_image)))
                    print('   Loss: {:5.1f}'.format(total_loss))
                    print('   Took: {} seconds'.format(time.time() - start_time))
                    start_time = time.time()

                    filename = 'outputs/%d.png' % (index)
                    utils.save_image(filename, gen_image)

                    if (index + 1) % 20 == 0:
                        ###############################
                        ## TO DO: save the variables into a checkpoint
                        saver.save(sess, 'checkpoints/style_transfer/style_transfer', index)
                        ###############################

if __name__ == '__main__':
    setup()
    machine = StyleTransfer('content/deadpool.jpg', 'styles/guernica.jpg', 333, 250)
    machine.build()
    machine.train(300)
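    # Assumes content/deadpool.jpg and styles/guernica.jpg exist relative to this
    # script; 333 and 250 are the generated image's width and height in pixels.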