-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbiggan.py
More file actions
77 lines (62 loc) · 2.95 KB
/
biggan.py
File metadata and controls
77 lines (62 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
import numpy as np
from itertools import cycle
# TF-Hub handle of the BigGAN generator loaded by default.
#
# Alternative, more realistic output:
#     'https://tfhub.dev/deepmind/biggan-deep-512/2'
#
# Current choice, more ethereal output:
MODULE_PATH = "https://tfhub.dev/deepmind/biggan-512/2"
class BigGAN(object):
    """Wrapper around a TF-Hub BigGAN generator module, run in TF1 graph mode.

    On construction it resets the default graph, loads the hub module at
    ``module_path`` and creates one placeholder per module input. It then
    initializes the graph variables in a session that stays open in
    ``self.sess`` for repeated sampling. Call ``close()`` to release it.

    Attributes:
        inputs: dict mapping each module input name to its placeholder.
        input_z: placeholder for module input ``'z'``.
        dim_z: second dimension of ``input_z``'s static shape.
        input_y: placeholder for module input ``'y'``.
        vocab_size: second dimension of ``input_y``'s static shape (the
            label count).
        input_trunc: placeholder for module input ``'truncation'``.
        output: tensor obtained by applying the module to ``inputs``.
        sess: ``tf.Session`` used to initialize variables and run ``output``.
    """

    def __init__(self, module_path=MODULE_PATH):
        """Load the generator from ``module_path`` and open a session.

        Side effects: resets the TF default graph, prints the module path
        and disables eager execution.
        """
        tf.reset_default_graph()
        print('Loading BigGAN module from:', module_path)
        # Fix "RuntimeError: Exporting/importing meta graphs is not
        # supported when eager execution is enabled." raised when importing
        # the TF-Hub module.
        tf.disable_eager_execution()
        module = hub.Module(module_path)
        self.inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
                       for k, v in module.get_input_info_dict().items()}
        self.input_z = self.inputs['z']
        self.dim_z = self.input_z.shape.as_list()[1]
        self.input_y = self.inputs['y']
        # Second dimension of y, i.e. the label count.
        self.vocab_size = self.input_y.shape.as_list()[1]
        self.input_trunc = self.inputs['truncation']
        self.output = module(self.inputs)
        # Built after the module so it covers all of the module's variables.
        initializer = tf.global_variables_initializer()
        # Fix "could not create cudnn handle" by letting TF grow its GPU
        # memory allocation on demand.
        # See: https://github.com/tensorflow/tensorflow/issues/24496
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self.sess.run(initializer)

    def sample(self, vectors, labels, truncation, batch_size=1,
               save_callback=None):
        """Generate uint8 images for ``vectors``/``labels`` in mini-batches.

        NOTE: when ``save_callback`` is given, each batch is handed to it and
        not accumulated; the return value is then ``None``.

        Args:
            vectors: array of latent vectors fed as ``z``; sliced along its
                first axis (presumably shape ``(num, dim_z)`` — confirm).
            labels: array fed as ``y``, sliced along its first axis like
                ``vectors`` (presumably one-hot ``(num, vocab_size)`` —
                confirm).
            truncation: a scalar used for every batch, or a sequence with one
                value per batch (batch ``k`` uses ``truncation[k]``).
            batch_size: number of samples per ``sess.run`` call; must be >= 1.
            save_callback: optional callable receiving each uint8 batch.

        Returns:
            All image batches concatenated along axis 0 as a uint8 array, or
            ``None`` when ``save_callback`` is given.

        Raises:
            ValueError: if ``batch_size < 1``; if a per-batch ``truncation``
                sequence is shorter than the number of batches; or if
                ``vectors`` is empty and no ``save_callback`` is given
                (nothing to concatenate).
        """
        num = vectors.shape[0]
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
        num_batches = -(-num // batch_size)  # ceiling division

        truncation = np.asarray(truncation)
        if truncation.ndim == 0:
            # Scalar input: reuse the same truncation for every batch.
            truncations = cycle([truncation])
        elif len(truncation) < num_batches:
            # zip() below would otherwise stop early and silently drop samples.
            raise ValueError(
                'truncation has %d per-batch values but %d batches are needed'
                % (len(truncation), num_batches))
        else:
            truncations = truncation

        ims = []
        for batch_start, trunc in zip(range(0, num, batch_size), truncations):
            s = slice(batch_start, min(num, batch_start + batch_size))
            feed_dict = {self.input_z: vectors[s],
                         self.input_y: labels[s],
                         self.input_trunc: trunc}
            ims_batch = self.sess.run(self.output, feed_dict=feed_dict)
            # Rescale (assumed [-1, 1] generator output) to [0, 256] and clip
            # to 255: values above 255 would wrap around in the uint8 cast
            # (e.g. 1.0 -> 256 -> 0, a black pixel instead of a white one).
            ims_batch = np.clip(((ims_batch + 1) / 2.0) * 256, 0, 255)
            ims_batch = np.uint8(ims_batch)
            if save_callback is None:
                ims.append(ims_batch)
            else:
                save_callback(ims_batch)

        if save_callback is not None:
            return None
        ims = np.concatenate(ims, axis=0)
        assert ims.shape[0] == num
        return ims

    def close(self):
        """Close the TF session held in ``self.sess``."""
        self.sess.close()