Skip to content

Commit 1be20ee

Browse files
authored
Merge pull request #7 from borisdayma/fix_do_init
fix: uses _do_init arg
2 parents e9bdbbc + a9c5735 commit 1be20ee

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

vqgan_jax/modeling_flax_vqgan.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -589,14 +589,16 @@ def __init__(
589589
input_shape: Tuple = (1, 256, 256, 3),
590590
seed: int = 0,
591591
dtype: jnp.dtype = jnp.float32,
592+
_do_init: bool = True,
592593
**kwargs,
593594
):
594595
module = self.module_class(config=config, dtype=dtype, **kwargs)
595596
super().__init__(config,
596597
module,
597598
input_shape=input_shape,
598599
seed=seed,
599-
dtype=dtype)
600+
dtype=dtype,
601+
_do_init=_do_init)
600602

601603
def init_weights(self, rng: jax.random.PRNGKey,
602604
input_shape: Tuple) -> FrozenDict:

0 commit comments

Comments
 (0)