Largely follows https://github.com/deepmind/dm-haiku/blob/master/examples/vae.py"""

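+# Standard-library timer used below to measure training throughput (examples/s).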
+import time
import argparse
import pathlib
from calendar import c
@@ -26,22 +27,16 @@


def add_args(parser):
-    parser.add_argument("--latent_size", type=int, default=10)
+    parser.add_argument("--latent_size", type=int, default=128)
    parser.add_argument("--hidden_size", type=int, default=512)
-    parser.add_argument("--variational", choices=["flow", "mean-field"])
-    parser.add_argument("--flow_depth", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--training_steps", type=int, default=100000)
    parser.add_argument("--log_interval", type=int, default=10000)
-    parser.add_argument("--early_stopping_interval", type=int, default=5)
-    parser.add_argument("--n_samples", type=int, default=128)
-    parser.add_argument(
-        "--use_gpu", default=False, action=argparse.BooleanOptionalAction
-    )
+    parser.add_argument("--num_eval_samples", type=int, default=128)
+    parser.add_argument("--gpu", default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument("--random_seed", type=int, default=42)
    parser.add_argument("--train_dir", type=pathlib.Path, default="/tmp")
-    parser.add_argument("--data_dir", type=pathlib.Path, default="/tmp")


def load_dataset(
@@ -78,8 +73,8 @@ def __init__(
            [
                hk.Linear(self._hidden_size),
                jax.nn.relu,
-                # hk.Linear(self._hidden_size),
-                # jax.nn.relu,
+                hk.Linear(self._hidden_size),
+                jax.nn.relu,
                hk.Linear(np.prod(self._output_shape)),
                hk.Reshape(self._output_shape, preserve_dims=2),
            ]
@@ -106,8 +101,8 @@ def __init__(self, latent_size: int, hidden_size: int):
                hk.Flatten(),
                hk.Linear(self._hidden_size),
                jax.nn.relu,
-                # hk.Linear(self._hidden_size),
-                # jax.nn.relu,
+                hk.Linear(self._hidden_size),
+                jax.nn.relu,
                hk.Linear(self._latent_size * 2),
            ]
        )
@@ -187,10 +182,10 @@ def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
    params = model_and_variational.init(
        next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE))
    )
-    optimizer = optax.adam(args.learning_rate)
+    optimizer = optax.rmsprop(args.learning_rate)
    opt_state = optimizer.init(params)

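+    # JIT-compile the update step so the gradient computation and optimizer update run as one fused XLA call.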
-    # @jax.jit
+    @jax.jit
    def train_step(
        params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
    ) -> Tuple[hk.Params, optax.OptState]:
@@ -245,24 +240,38 @@ def evaluate(
    )
    test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed)

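+    # Periodic progress report: validation ELBO and log p(x), a one-batch train ELBO estimate, and throughput.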
+    def print_progress(step: int, examples_per_sec: float):
+        valid_ds = load_dataset(
+            tfds.Split.VALIDATION, args.batch_size, args.random_seed
+        )
+        elbo, log_p_x = evaluate(valid_ds, params, rng_seq)
+        train_elbo = (
+            -objective_fn(params, next(rng_seq), next(train_ds)) / args.batch_size
+        )
+        print(
+            f"Step {step:<10d}\t"
+            f"Train ELBO estimate: {train_elbo:<5.3f}\t"
+            f"Validation ELBO estimate: {elbo:<5.3f}\t"
+            f"Validation log p(x) estimate: {log_p_x:<5.3f}\t"
+            f"Speed: {examples_per_sec:<5.0f} examples/s"
+        )
+
+    t0 = time.time()
    for step in range(args.training_steps):
-        params, opt_state = train_step(params, next(rng_seq), opt_state, next(train_ds))
        if step % args.log_interval == 0:
-            valid_ds = load_dataset(
-                tfds.Split.VALIDATION, args.batch_size, args.random_seed
-            )
-            elbo, log_p_x = evaluate(valid_ds, params, rng_seq)
-            train_elbo = (
-                -objective_fn(params, next(rng_seq), next(train_ds)) / args.batch_size
-            )
-            print(
-                f"Step {step:<10d}\t"
-                f"Train ELBO estimate: {train_elbo:<5.3f}\t"
-                f"Validation ELBO estimate: {elbo:<5.3f}\t"
-                f"Validation log p(x) estimate: {log_p_x:<5.3f}"
-            )
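+            # Each train_step consumes one batch of args.batch_size examples, so scale steps/s by the batch size.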
+            examples_per_sec = args.log_interval * args.batch_size / (time.time() - t0)
+            print_progress(step, examples_per_sec)
+            t0 = time.time()
+        params, opt_state = train_step(params, next(rng_seq), opt_state, next(train_ds))
+
+    test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed)
+    elbo, log_p_x = evaluate(test_ds, params, rng_seq)
+    print(
+        f"Step {step:<10d}\t"
+        f"Test ELBO estimate: {elbo:<5.3f}\t"
+        f"Test log p(x) estimate: {log_p_x:<5.3f}"
+    )


if __name__ == "__main__":
    main()
-