Skip to content

Commit 8bbe5e2

Browse files
committed
better docstring for the use_pytorch_weights function
1 parent 7a0f8db commit 8bbe5e2

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

custom_pytorch_jax_converter.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,19 @@
77
import copy
88
from jax.tree_util import tree_map
99

10-
"""
11-
Jax default parameter structure:
12-
dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])
13-
14-
Pytorch stateduct structure:
15-
dict_keys(['embedding_chunk_0', 'embedding_chunk_1', 'embedding_chunk_2', 'embedding_chunk_3', 'bot_mlp.0.weight', 'bot_mlp.0.bias', 'bot_mlp.2.weight', 'bot_mlp.2.bias', 'bot_mlp.4.weight', 'bot_mlp.4.bias', 'top_mlp.0.weight', 'top_mlp.0.bias', 'top_mlp.2.weight', 'top_mlp.2.bias', 'top_mlp.4.weight', 'top_mlp.4.bias', 'top_mlp.6.weight', 'top_mlp.6.bias', 'top_mlp.8.weight', 'top_mlp.8.bias'])
1610

11+
def use_pytorch_weights(file_name: str):
12+
"""
13+
Jax default parameter structure:
14+
dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])
1715
16+
Pytorch stateduct structure:
17+
dict_keys(['embedding_chunk_0', 'embedding_chunk_1', 'embedding_chunk_2', 'embedding_chunk_3', 'bot_mlp.0.weight', 'bot_mlp.0.bias', 'bot_mlp.2.weight', 'bot_mlp.2.bias', 'bot_mlp.4.weight', 'bot_mlp.4.bias', 'top_mlp.0.weight', 'top_mlp.0.bias', 'top_mlp.2.weight', 'top_mlp.2.bias', 'top_mlp.4.weight', 'top_mlp.4.bias', 'top_mlp.6.weight', 'top_mlp.6.bias', 'top_mlp.8.weight', 'top_mlp.8.bias'])
1818
19-
The following function converts the PyTorch weights to the Jax format
20-
and assigns them to the Jax model parameters.
21-
The function assumes that the Jax model parameters are already initialized
22-
and that the PyTorch weights are in the correct format.
23-
"""
2419
25-
def use_pytorch_weights(file_name: str):
20+
The following function converts the PyTorch weights to the Jax format
21+
"""
22+
2623
jax_copy = {}
2724

2825
# Load PyTorch state_dict lazily to CPU

0 commit comments

Comments
 (0)