 import copy
 from jax.tree_util import tree_map
 
-"""
-Jax default parameter structure:
-dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])
-
-PyTorch state_dict structure:
-dict_keys(['embedding_chunk_0', 'embedding_chunk_1', 'embedding_chunk_2', 'embedding_chunk_3', 'bot_mlp.0.weight', 'bot_mlp.0.bias', 'bot_mlp.2.weight', 'bot_mlp.2.bias', 'bot_mlp.4.weight', 'bot_mlp.4.bias', 'top_mlp.0.weight', 'top_mlp.0.bias', 'top_mlp.2.weight', 'top_mlp.2.bias', 'top_mlp.4.weight', 'top_mlp.4.bias', 'top_mlp.6.weight', 'top_mlp.6.bias', 'top_mlp.8.weight', 'top_mlp.8.bias'])
 
+def use_pytorch_weights(file_name: str):
+    """
+    Jax default parameter structure:
+    dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])
 
+    PyTorch state_dict structure:
+    dict_keys(['embedding_chunk_0', 'embedding_chunk_1', 'embedding_chunk_2', 'embedding_chunk_3', 'bot_mlp.0.weight', 'bot_mlp.0.bias', 'bot_mlp.2.weight', 'bot_mlp.2.bias', 'bot_mlp.4.weight', 'bot_mlp.4.bias', 'top_mlp.0.weight', 'top_mlp.0.bias', 'top_mlp.2.weight', 'top_mlp.2.bias', 'top_mlp.4.weight', 'top_mlp.4.bias', 'top_mlp.6.weight', 'top_mlp.6.bias', 'top_mlp.8.weight', 'top_mlp.8.bias'])
 
-The following function converts the PyTorch weights to the Jax format
-and assigns them to the Jax model parameters.
-The function assumes that the Jax model parameters are already initialized
-and that the PyTorch weights are in the correct format.
-"""
 
-def use_pytorch_weights(file_name: str):
+    The following function converts the PyTorch weights to the Jax format
+    """
+
     jax_copy = {}
 
     # Load PyTorch state_dict lazily to CPU
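For illustration, here is a minimal sketch of what such a conversion can look like, assuming the key mapping suggested by the two structures above (the embedding chunks stacked row-wise into 'embedding_table', bot_mlp.{0,2,4} mapped to Dense_0..Dense_2 and top_mlp.{0,2,4,6,8} to Dense_3..Dense_7) and Flax's default 'kernel'/'bias' parameter names. The helper name and the mapping are assumptions for this sketch, not the repository's actual code.

# Sketch only: the embedding concatenation axis, the bot_mlp/top_mlp -> Dense_i
# ordering, and the 'kernel'/'bias' sub-keys are assumptions, not taken from
# the repository.
import numpy as np
import torch


def convert_pytorch_state_dict(file_name: str) -> dict:
    # Load the checkpoint onto the CPU so no GPU is needed for the conversion.
    state_dict = torch.load(file_name, map_location="cpu")

    jax_params = {}

    # Assumption: the four embedding chunks are row-wise slices of one table.
    chunks = [state_dict[f"embedding_chunk_{i}"].numpy() for i in range(4)]
    jax_params["embedding_table"] = np.concatenate(chunks, axis=0)

    # torch.nn.Linear stores weight as (out_features, in_features), while
    # flax.linen.Dense keeps its kernel as (in_features, out_features),
    # so every weight matrix is transposed on the way over.
    linear_layers = [
        "bot_mlp.0", "bot_mlp.2", "bot_mlp.4",            # assumed -> Dense_0..2
        "top_mlp.0", "top_mlp.2", "top_mlp.4",            # assumed -> Dense_3..5
        "top_mlp.6", "top_mlp.8",                         # assumed -> Dense_6..7
    ]
    for i, layer in enumerate(linear_layers):
        jax_params[f"Dense_{i}"] = {
            "kernel": state_dict[f"{layer}.weight"].numpy().T,
            "bias": state_dict[f"{layer}.bias"].numpy(),
        }

    return jax_params

From here, the jax.tree_util.tree_map imported above can be used to cast every leaf to a jax.numpy array or to compare shapes against the already initialized Jax parameters before copying them in.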