Skip to content

Commit f049960

Browse files
author
Jaan Altosaar
committed
fix jax implementation
1 parent f4c3ece commit f049960

File tree

3 files changed

+60
-49
lines changed

3 files changed

+60
-49
lines changed

README.md

+7-5
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ step: 30000 train elbo: -98.70
3636
step: 30000 valid elbo: -103.76 valid log p(x): -97.71
3737
```
3838

39-
Using jax:
39+
Using jax (anaconda environment is in `environment-jax.yml`):
4040
```
41-
Step 0 Validation ELBO estimate: -507.485 Validation log p(x) estimate: -507.485
42-
Step 10000 Validation ELBO estimate: -152.695 Validation log p(x) estimate: -152.695
43-
Step 20000 Validation ELBO estimate: -150.413 Validation log p(x) estimate: -150.413
44-
Step 30000 Validation ELBO estimate: -150.529 Validation log p(x) estimate: -150.529
41+
Step 0 Train ELBO estimate: -565.785 Validation ELBO estimate: -565.775 Validation log p(x) estimate: -565.775 Speed: 3813003636 examples/s
42+
Step 10000 Train ELBO estimate: -99.048 Validation ELBO estimate: -105.412 Validation log p(x) estimate: -105.412 Speed: 134 examples/s
43+
Step 20000 Train ELBO estimate: -108.399 Validation ELBO estimate: -105.191 Validation log p(x) estimate: -105.191 Speed: 140 examples/s
44+
Step 30000 Train ELBO estimate: -100.839 Validation ELBO estimate: -105.404 Validation log p(x) estimate: -105.404 Speed: 139 examples/s
45+
Step 40000 Train ELBO estimate: -97.761 Validation ELBO estimate: -105.382 Validation log p(x) estimate: -105.382 Speed: 139 examples/s
46+
Step 50000 Train ELBO estimate: -98.228 Validation ELBO estimate: -105.718 Validation log p(x) estimate: -105.718 Speed: 139 examples/s
4547
```

environment_jax.yml renamed to environment-jax.yml

+14-14
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: jax
1+
name: /scratch/gpfs/altosaar/environment-jax
22
channels:
33
- defaults
44
dependencies:
@@ -11,7 +11,6 @@ dependencies:
1111
- libstdcxx-ng=9.1.0=hdf63c60_0
1212
- ncurses=6.2=he6710b0_1
1313
- openssl=1.1.1k=h27cfd23_0
14-
- pip=21.1.1=py39h06a4308_0
1514
- python=3.9.5=hdb3f193_3
1615
- readline=8.1=h27cfd23_0
1716
- setuptools=52.0.0=py39h06a4308_0
@@ -36,25 +35,26 @@ dependencies:
3635
- flatbuffers==1.12
3736
- future==0.18.2
3837
- gast==0.4.0
39-
- google-auth==1.30.0
38+
- google-auth==1.30.1
4039
- google-auth-oauthlib==0.4.4
4140
- google-pasta==0.2.0
4241
- googleapis-common-protos==1.53.0
43-
- grpcio==1.34.1
42+
- grpcio==1.38.0
4443
- h5py==3.1.0
4544
- idna==2.10
4645
- jax==0.2.13
47-
- jaxlib==0.1.67
46+
- jaxlib==0.1.67+cuda111
4847
- jmp==0.0.2
49-
- keras-nightly==2.5.0.dev2021032900
48+
- keras-nightly==2.6.0.dev2021052500
5049
- keras-preprocessing==1.1.2
5150
- markdown==3.3.4
5251
- numpy==1.19.5
5352
- oauthlib==3.1.0
5453
- opt-einsum==3.3.0
55-
- optax==0.0.7
54+
- optax==0.0.6
55+
- pip==21.1.2
5656
- promise==2.3
57-
- protobuf==3.17.0
57+
- protobuf==3.17.1
5858
- pyasn1==0.4.8
5959
- pyasn1-modules==0.2.8
6060
- requests==2.25.1
@@ -63,19 +63,19 @@ dependencies:
6363
- scipy==1.6.3
6464
- six==1.15.0
6565
- tabulate==0.8.9
66-
- tensorboard==2.5.0
66+
- tb-nightly==2.6.0a20210525
6767
- tensorboard-data-server==0.6.1
6868
- tensorboard-plugin-wit==1.8.0
69-
- tensorflow==2.5.0
7069
- tensorflow-datasets==4.3.0
71-
- tensorflow-estimator==2.5.0
7270
- tensorflow-metadata==1.0.0
7371
- termcolor==1.1.0
74-
- tfp-nightly==0.14.0.dev20210521
72+
- tf-estimator-nightly==2.5.0.dev2021032601
73+
- tf-nightly==2.6.0.dev20210525
74+
- tfp-nightly==0.14.0.dev20210525
7575
- toolz==0.11.1
76-
- tqdm==4.60.0
76+
- tqdm==4.61.0
7777
- typing-extensions==3.7.4.3
7878
- urllib3==1.26.4
7979
- werkzeug==2.0.1
8080
- wrapt==1.12.1
81-
prefix: /home/jaan/miniconda3/envs/jax
81+
prefix: /scratch/gpfs/altosaar/environment-jax

train_variational_autoencoder_jax.py

+39-30
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
Largely follows https://github.com/deepmind/dm-haiku/blob/master/examples/vae.py"""
44

5+
import time
56
import argparse
67
import pathlib
78
from calendar import c
@@ -26,22 +27,16 @@
2627

2728

2829
def add_args(parser):
29-
parser.add_argument("--latent_size", type=int, default=10)
30+
parser.add_argument("--latent_size", type=int, default=128)
3031
parser.add_argument("--hidden_size", type=int, default=512)
31-
parser.add_argument("--variational", choices=["flow", "mean-field"])
32-
parser.add_argument("--flow_depth", type=int, default=2)
3332
parser.add_argument("--learning_rate", type=float, default=0.001)
3433
parser.add_argument("--batch_size", type=int, default=128)
3534
parser.add_argument("--training_steps", type=int, default=100000)
3635
parser.add_argument("--log_interval", type=int, default=10000)
37-
parser.add_argument("--early_stopping_interval", type=int, default=5)
38-
parser.add_argument("--n_samples", type=int, default=128)
39-
parser.add_argument(
40-
"--use_gpu", default=False, action=argparse.BooleanOptionalAction
41-
)
36+
parser.add_argument("--num_eval_samples", type=int, default=128)
37+
parser.add_argument("--gpu", default=False, action=argparse.BooleanOptionalAction)
4238
parser.add_argument("--random_seed", type=int, default=42)
4339
parser.add_argument("--train_dir", type=pathlib.Path, default="/tmp")
44-
parser.add_argument("--data_dir", type=pathlib.Path, default="/tmp")
4540

4641

4742
def load_dataset(
@@ -78,8 +73,8 @@ def __init__(
7873
[
7974
hk.Linear(self._hidden_size),
8075
jax.nn.relu,
81-
# hk.Linear(self._hidden_size),
82-
# jax.nn.relu,
76+
hk.Linear(self._hidden_size),
77+
jax.nn.relu,
8378
hk.Linear(np.prod(self._output_shape)),
8479
hk.Reshape(self._output_shape, preserve_dims=2),
8580
]
@@ -106,8 +101,8 @@ def __init__(self, latent_size: int, hidden_size: int):
106101
hk.Flatten(),
107102
hk.Linear(self._hidden_size),
108103
jax.nn.relu,
109-
# hk.Linear(self._hidden_size),
110-
# jax.nn.relu,
104+
hk.Linear(self._hidden_size),
105+
jax.nn.relu,
111106
hk.Linear(self._latent_size * 2),
112107
]
113108
)
@@ -187,10 +182,10 @@ def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarr
187182
params = model_and_variational.init(
188183
next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE))
189184
)
190-
optimizer = optax.adam(args.learning_rate)
185+
optimizer = optax.rmsprop(args.learning_rate)
191186
opt_state = optimizer.init(params)
192187

193-
# @jax.jit
188+
@jax.jit
194189
def train_step(
195190
params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
196191
) -> Tuple[hk.Params, optax.OptState]:
@@ -245,24 +240,38 @@ def evaluate(
245240
)
246241
test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed)
247242

243+
def print_progress(step: int, examples_per_sec: float):
244+
valid_ds = load_dataset(
245+
tfds.Split.VALIDATION, args.batch_size, args.random_seed
246+
)
247+
elbo, log_p_x = evaluate(valid_ds, params, rng_seq)
248+
train_elbo = (
249+
-objective_fn(params, next(rng_seq), next(train_ds)) / args.batch_size
250+
)
251+
print(
252+
f"Step {step:<10d}\t"
253+
f"Train ELBO estimate: {train_elbo:<5.3f}\t"
254+
f"Validation ELBO estimate: {elbo:<5.3f}\t"
255+
f"Validation log p(x) estimate: {log_p_x:<5.3f}\t"
256+
f"Speed: {examples_per_sec:<5.0f} examples/s"
257+
)
258+
259+
t0 = time.time()
248260
for step in range(args.training_steps):
249-
params, opt_state = train_step(params, next(rng_seq), opt_state, next(train_ds))
250261
if step % args.log_interval == 0:
251-
valid_ds = load_dataset(
252-
tfds.Split.VALIDATION, args.batch_size, args.random_seed
253-
)
254-
elbo, log_p_x = evaluate(valid_ds, params, rng_seq)
255-
train_elbo = (
256-
-objective_fn(params, next(rng_seq), next(train_ds)) / args.batch_size
257-
)
258-
print(
259-
f"Step {step:<10d}\t"
260-
f"Train ELBO estimate: {train_elbo:<5.3f}\t"
261-
f"Validation ELBO estimate: {elbo:<5.3f}\t"
262-
f"Validation log p(x) estimate: {log_p_x:<5.3f}"
263-
)
262+
examples_per_sec = args.log_interval / (time.time() - t0)
263+
print_progress(step, examples_per_sec)
264+
t0 = time.time()
265+
params, opt_state = train_step(params, next(rng_seq), opt_state, next(train_ds))
266+
267+
test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed)
268+
elbo, log_p_x = evaluate(test_ds, params, rng_seq)
269+
print(
270+
f"Step {step:<10d}\t"
271+
f"Test ELBO estimate: {elbo:<5.3f}\t"
272+
f"Test log p(x) estimate: {log_p_x:<5.3f}\t"
273+
)
264274

265275

266276
if __name__ == "__main__":
267277
main()
268-

0 commit comments

Comments
 (0)