@@ -50,17 +50,13 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
50
50
51
51
# Correctly rename weight parameters
52
52
if ("norm" in pt_key and (pt_tuple_key [- 1 ] == "bias" )
53
- and (pt_tuple_key [:- 1 ] + ("bias" , ) in random_flax_state_dict )):
54
- pt_tensor = pt_tensor [None , None , None , :]
55
- elif ("norm" in pt_key and (pt_tuple_key [- 1 ] == "bias" )
53
+ and (pt_tuple_key [:- 1 ] + ("bias" , ) not in random_flax_state_dict )
56
54
and (pt_tuple_key [:- 1 ] + ("scale" , ) in random_flax_state_dict )):
57
55
pt_tuple_key = pt_tuple_key [:- 1 ] + ("scale" , )
58
- pt_tensor = pt_tensor [None , None , None , :]
59
56
elif pt_tuple_key [- 1 ] in [
60
57
"weight" , "gamma"
61
58
] and pt_tuple_key [:- 1 ] + ("scale" , ) in random_flax_state_dict :
62
59
pt_tuple_key = pt_tuple_key [:- 1 ] + ("scale" , )
63
- pt_tensor = pt_tensor [None , None , None , :]
64
60
if pt_tuple_key [- 1 ] == "weight" and pt_tuple_key [:- 1 ] + (
65
61
"embedding" , ) in random_flax_state_dict :
66
62
pt_tuple_key = pt_tuple_key [:- 1 ] + ("embedding" , )
0 commit comments