Fix create_nnx_model() to load NNX checkpoint, verified with vllm_decode #2478
🤖 Hi @hengtaoguo, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
  item_to_restore = {"params": {"params": target_for_restore}}
  restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}}
else:
  # structure of nnx checkpoint: {'decoder': {'value': ...}}
Regarding your comments:
# structure of linen checkpoint: {'params': {'params': {'decoder': ...}}}
# structure of nnx checkpoint: {'decoder': {'value': ...}}
Is the `...` part structurally identical between Linen and NNX? If so, what is the tree structure of the `checkpoint` variable below?
Yes, `...` is the identical part. Here is what the `checkpoint` variable would look like for both Linen and NNX: https://paste.googleplex.com/4702000029761536
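To make the two layouts from the code comments concrete, here is a minimal sketch using plain dicts. The `payload` contents are made up for illustration, and the nesting follows the two comment lines literally; the real checkpoint tree (in the linked paste) may differ in detail.

```python
# Made-up stand-in for the shared `...` part of both layouts.
payload = {"layers": {"kernel": [0.1, 0.2]}}

# Layouts as written in the code comments quoted in this thread.
linen_checkpoint = {"params": {"params": {"decoder": payload}}}
nnx_checkpoint = {"decoder": {"value": payload}}

# Each layout reaches the shared part through a different path.
from_linen = linen_checkpoint["params"]["params"]["decoder"]
from_nnx = nnx_checkpoint["decoder"]["value"]
```

Under these assumptions, `from_linen` and `from_nnx` refer to the same subtree, which is what "the identical part" means here.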
per_device_batch_size=1 run_name=vllm_decode_test \
use_chat_template=True prompt="Suggest some famous landmarks in London." \
per_device_batch_size=1 run_name=vllm_decode_test max_target_length=64 \
use_chat_template=False prompt="Suggest some famous landmarks in London." \
What is the issue if we use `sft.yml`?
is_nnx_checkpoint = True
if (
    "params" in metadata.item_metadata.tree.keys()
I think a better way to decide whether it is an NNX checkpoint is to check for `nnx.Variable` in the leaves.
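For comparison, the key-based test visible in the diff can be sketched on plain dicts. This is only an illustration: `is_nnx_layout` is a hypothetical helper, the dicts stand in for `metadata.item_metadata.tree`, and the rest of the PR's actual condition is not shown in this thread.

```python
def is_nnx_layout(tree: dict) -> bool:
    """Key-based guess: the Linen layout nests everything under 'params' -> 'params'."""
    inner = tree.get("params")
    is_linen = isinstance(inner, dict) and "params" in inner
    return not is_linen


# Toy metadata trees following the two layouts quoted in the code comments.
linen_tree = {"params": {"params": {"decoder": {}}}}
nnx_tree = {"decoder": {"value": {}}}
```

A key-based check like this depends only on the saved tree's top-level names, whereas a leaf-based check would depend on how the leaves are represented.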
target_for_restore = jax.tree.map(
    lambda v: v.value,
    sharded_state,
    is_leaf=lambda n: isinstance(n, nnx.Variable),
If it is not an NNX checkpoint, i.e. `is_nnx_checkpoint = False`, why do we have the `nnx.Variable` check?
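The unwrapping step in the snippet under review can be illustrated without JAX or Flax. In this sketch, `Variable` is a stand-in for `nnx.Variable` and `tree_map` is a tiny dict-only analogue of `jax.tree.map`; both are assumptions made for illustration. One possible reading of the code, not confirmed in this thread, is that `sharded_state` comes from the NNX model rather than from the checkpoint, so its leaves would be variables regardless of the checkpoint format.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Variable:
    """Stand-in for nnx.Variable: a container holding one value."""
    value: Any


def tree_map(fn: Callable, tree: Any, is_leaf: Callable[[Any], bool]) -> Any:
    """Dict-only analogue of jax.tree.map with an is_leaf predicate."""
    if is_leaf(tree) or not isinstance(tree, dict):
        return fn(tree)
    return {key: tree_map(fn, sub, is_leaf) for key, sub in tree.items()}


# Toy model state whose leaves are wrapped in Variable objects.
sharded_state = {"decoder": {"kernel": Variable([1.0, 2.0]), "bias": Variable([0.0])}}

# Treat each Variable as a leaf and replace it with the value it holds.
target_for_restore = tree_map(
    lambda v: v.value,
    sharded_state,
    is_leaf=lambda n: isinstance(n, Variable),
)
```

Without the `is_leaf` predicate, a tree traversal could descend into the wrapper itself instead of handing the whole `Variable` to the mapped function.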
Description
Fix `create_nnx_model()` to load both Linen and NNX checkpoints.
Tests
Tested loading both Linen and NNX checkpoints and compared the params before and after updating the model.
Loading linen checkpoint: https://paste.googleplex.com/5840546769797120
Loading NNX checkpoint (without pathways): https://paste.googleplex.com/5153322256433152
Loading NNX checkpoint (with pathways): https://paste.googleplex.com/5967825172824064
vLLM Decoding Results: http://shortn/_QHI6tVbeYP
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.