forked from roimehrez/contextualLoss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvgg_model.py
59 lines (50 loc) · 3.11 KB
/
vgg_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# ---------------------------------------------------
# code credits: https://github.com/CQFIO/PhotographicImageSynthesis
# ---------------------------------------------------
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import scipy.io
from config import *
def lrelu(x):
    """Leaky ReLU with a fixed negative slope of 0.2.

    Evaluates to ``x`` for non-negative entries and ``0.2 * x`` for negative
    ones, via an element-wise maximum of the two candidates.
    """
    leaked = 0.2 * x
    # The leaked branch stays the first argument to tf.maximum: at the x == 0
    # tie the argument order decides which branch receives the gradient.
    return tf.maximum(leaked, x)
def build_net(ntype, nin, nwb=None, name=None):
    """Apply one layer of kind ``ntype`` to the tensor ``nin``.

    Args:
        ntype: ``'conv'`` for a stride-1, SAME-padded ``tf.nn.conv2d`` followed
            by a bias add and ReLU; ``'pool'`` for a 2x2, stride-2,
            SAME-padded average pool.
        nin: input tensor. The strides/ksize are written for 4-D NHWC input
            (the ``tf.nn`` default layout) -- NOTE(review): confirm callers.
        nwb: ``(weights, bias)`` pair indexed as ``nwb[0]`` / ``nwb[1]``;
            required for ``'conv'``, ignored for ``'pool'``.
        name: name given to the conv2d op; ignored for ``'pool'``.

    Returns:
        The layer's output tensor.

    Raises:
        ValueError: if ``ntype`` is neither ``'conv'`` nor ``'pool'``, so a
            typo fails here instead of silently yielding ``None``.
    """
    if ntype == 'conv':
        conv = tf.nn.conv2d(nin, nwb[0], strides=[1, 1, 1, 1], padding='SAME', name=name)
        return tf.nn.relu(conv + nwb[1])
    if ntype == 'pool':
        # Average (not max) pooling is used for the downsampling layers.
        return tf.nn.avg_pool(nin, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    raise ValueError("unknown layer type %r (expected 'conv' or 'pool')" % (ntype,))
def get_weight_bias(vgg_layers, i):
    """Return ``(weights, bias)`` for layer ``i`` as TensorFlow constants.

    ``vgg_layers`` is walked as the nested object arrays that
    ``scipy.io.loadmat`` produces; element 2 of the layer record is read as a
    two-item container holding the kernel and the bias (NOTE(review): this
    is presumably the MatConvNet ``weights`` field -- confirm against the
    .mat file). The bias is flattened to 1-D before being wrapped.
    """
    params = vgg_layers[i][0][0][2][0]
    # Kernel constant is created before the bias constant (op-creation order).
    kernel = tf.constant(params[0])
    flat_bias = params[1].reshape(-1)
    return kernel, tf.constant(flat_bias)
def build_vgg19(input, reuse=False):
if reuse:
tf.get_variable_scope().reuse_variables()
net = {}
vgg_rawnet = scipy.io.loadmat(config.vgg_model_path)
vgg_layers = vgg_rawnet['layers'][0]
net['input'] = input - np.array([123.6800, 116.7790, 103.9390]).reshape((1, 1, 1, 3))
net['conv1_1'] = build_net('conv', net['input'], get_weight_bias(vgg_layers, 0), name='vgg_conv1_1')
net['conv1_2'] = build_net('conv', net['conv1_1'], get_weight_bias(vgg_layers, 2), name='vgg_conv1_2')
net['pool1'] = build_net('pool', net['conv1_2'])
net['conv2_1'] = build_net('conv', net['pool1'], get_weight_bias(vgg_layers, 5), name='vgg_conv2_1')
net['conv2_2'] = build_net('conv', net['conv2_1'], get_weight_bias(vgg_layers, 7), name='vgg_conv2_2')
net['pool2'] = build_net('pool', net['conv2_2'])
net['conv3_1'] = build_net('conv', net['pool2'], get_weight_bias(vgg_layers, 10), name='vgg_conv3_1')
net['conv3_2'] = build_net('conv', net['conv3_1'], get_weight_bias(vgg_layers, 12), name='vgg_conv3_2')
net['conv3_3'] = build_net('conv', net['conv3_2'], get_weight_bias(vgg_layers, 14), name='vgg_conv3_3')
net['conv3_4'] = build_net('conv', net['conv3_3'], get_weight_bias(vgg_layers, 16), name='vgg_conv3_4')
net['pool3'] = build_net('pool', net['conv3_4'])
net['conv4_1'] = build_net('conv', net['pool3'], get_weight_bias(vgg_layers, 19), name='vgg_conv4_1')
net['conv4_2'] = build_net('conv', net['conv4_1'], get_weight_bias(vgg_layers, 21), name='vgg_conv4_2')
net['conv4_3'] = build_net('conv', net['conv4_2'], get_weight_bias(vgg_layers, 23), name='vgg_conv4_3')
net['conv4_4'] = build_net('conv', net['conv4_3'], get_weight_bias(vgg_layers, 25), name='vgg_conv4_4')
net['pool4'] = build_net('pool', net['conv4_4'])
net['conv5_1'] = build_net('conv', net['pool4'], get_weight_bias(vgg_layers, 28), name='vgg_conv5_1')
net['conv5_2'] = build_net('conv', net['conv5_1'], get_weight_bias(vgg_layers, 30), name='vgg_conv5_2')
net['conv5_3'] = build_net('conv', net['conv5_2'], get_weight_bias(vgg_layers, 32), name='vgg_conv5_3')
net['conv5_4'] = build_net('conv', net['conv5_3'], get_weight_bias(vgg_layers, 34), name='vgg_conv5_4')
net['pool5'] = build_net('pool', net['conv5_4'])
return net