Skip to content

Commit e9bdbbc

Browse files
authored
Merge pull request #4 from borisdayma/fix-flax
fix: handle flax 0.3.6
2 parents 610d842 + f49a07f commit e9bdbbc

File tree

2 files changed

+1
-7
lines changed

2 files changed

+1
-7
lines changed

vqgan_jax/convert_pt_model_to_jax.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,13 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
5050

5151
# Correctly rename weight parameters
5252
if ("norm" in pt_key and (pt_tuple_key[-1] == "bias")
53-
and (pt_tuple_key[:-1] + ("bias", ) in random_flax_state_dict)):
54-
pt_tensor = pt_tensor[None, None, None, :]
55-
elif ("norm" in pt_key and (pt_tuple_key[-1] == "bias")
53+
and (pt_tuple_key[:-1] + ("bias", ) not in random_flax_state_dict)
5654
and (pt_tuple_key[:-1] + ("scale", ) in random_flax_state_dict)):
5755
pt_tuple_key = pt_tuple_key[:-1] + ("scale", )
58-
pt_tensor = pt_tensor[None, None, None, :]
5956
elif pt_tuple_key[-1] in [
6057
"weight", "gamma"
6158
] and pt_tuple_key[:-1] + ("scale", ) in random_flax_state_dict:
6259
pt_tuple_key = pt_tuple_key[:-1] + ("scale", )
63-
pt_tensor = pt_tensor[None, None, None, :]
6460
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + (
6561
"embedding", ) in random_flax_state_dict:
6662
pt_tuple_key = pt_tuple_key[:-1] + ("embedding", )

vqgan_jax/modeling_flax_vqgan.py

-2
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,6 @@ def setup(self):
399399
- 1]
400400
curr_res = self.config.resolution // 2**(self.config.num_resolutions - 1)
401401
self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
402-
print("Working with z of shape {} = {} dimensions.".format(
403-
self.z_shape, np.prod(self.z_shape)))
404402

405403
# z to block_in
406404
self.conv_in = nn.Conv(

0 commit comments

Comments
 (0)