Skip to content

Commit 11b53a4

Browse files
committed
fix: handle flax 0.3.6
1 parent 610d842 commit 11b53a4

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

vqgan_jax/convert_pt_model_to_jax.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,13 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
5050

5151
# Correctly rename weight parameters
5252
if ("norm" in pt_key and (pt_tuple_key[-1] == "bias")
53-
and (pt_tuple_key[:-1] + ("bias", ) in random_flax_state_dict)):
54-
pt_tensor = pt_tensor[None, None, None, :]
55-
elif ("norm" in pt_key and (pt_tuple_key[-1] == "bias")
53+
and (pt_tuple_key[:-1] + ("bias", ) not in random_flax_state_dict)
5654
and (pt_tuple_key[:-1] + ("scale", ) in random_flax_state_dict)):
5755
pt_tuple_key = pt_tuple_key[:-1] + ("scale", )
58-
pt_tensor = pt_tensor[None, None, None, :]
5956
elif pt_tuple_key[-1] in [
6057
"weight", "gamma"
6158
] and pt_tuple_key[:-1] + ("scale", ) in random_flax_state_dict:
6259
pt_tuple_key = pt_tuple_key[:-1] + ("scale", )
63-
pt_tensor = pt_tensor[None, None, None, :]
6460
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + (
6561
"embedding", ) in random_flax_state_dict:
6662
pt_tuple_key = pt_tuple_key[:-1] + ("embedding", )

0 commit comments

Comments
 (0)