Conversation

SurbhiJainUSC
Collaborator

@SurbhiJainUSC SurbhiJainUSC commented Oct 9, 2025

Description

  • Update create_nnx_model() to load both Linen and NNX checkpoints
  • FIXES: b/450116695

Tests

Tested loading both Linen and NNX checkpoints and compared the params before and after updating the model.

Loading linen checkpoint: https://paste.googleplex.com/5840546769797120

Loading NNX checkpoint (without pathways): https://paste.googleplex.com/5153322256433152

Loading NNX checkpoint (with pathways): https://paste.googleplex.com/5967825172824064

vLLM Decoding Results: http://shortn/_QHI6tVbeYP

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@SurbhiJainUSC SurbhiJainUSC changed the title Fix create_nnx_model() to load NNX checkpoint Fix create_nnx_model() to load NNX checkpoint, verified with vllm_decode Oct 10, 2025
@SurbhiJainUSC SurbhiJainUSC force-pushed the nnx_ckpt_restore branch 3 times, most recently from 21bdda4 to 2d9ebf5 Compare October 10, 2025 19:48
🤖 Hi @hengtaoguo, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

item_to_restore = {"params": {"params": target_for_restore}}
restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}}
else:
# structure of nnx checkpoint: {'decoder': {'value': ...}}
Collaborator

@hengtaoguo hengtaoguo Oct 10, 2025

For your comments:

# structure of linen checkpoint: {'params': {'params': {'decoder': ...}}}
# structure of nnx checkpoint: {'decoder': {'value': ...}}

Is the ... part structurally identical between Linen and NNX? And for the checkpoint variable below, what is its tree structure?

Collaborator Author

Yes, the ... part is identical. Here is what the checkpoint variable would look like for both Linen and NNX: https://paste.googleplex.com/4702000029761536
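
To make the two layouts from the comment above concrete, here is a minimal sketch using plain Python dicts. The leaf name (`kernel`), the values, and the helpers `strip_value_wrappers` and `common_subtree` are hypothetical and not part of this PR. It also assumes the `'value'` wrappers sit at the leaves of the NNX tree, which the thread does not spell out.

```python
# Hypothetical toy checkpoints; real ones hold sharded arrays, not lists.
linen_ckpt = {"params": {"params": {"decoder": {"kernel": [1.0, 2.0]}}}}
nnx_ckpt = {"decoder": {"kernel": {"value": [1.0, 2.0]}}}


def strip_value_wrappers(tree):
    """Recursively replace every {'value': x} dict with x."""
    if isinstance(tree, dict):
        if set(tree.keys()) == {"value"}:
            return strip_value_wrappers(tree["value"])
        return {k: strip_value_wrappers(v) for k, v in tree.items()}
    return tree


def common_subtree(ckpt):
    """Return the part of the tree that both layouts share."""
    if "params" in ckpt:
        # Linen layout: unwrap the two nested 'params' levels.
        return ckpt["params"]["params"]
    # NNX layout (assumed): drop the per-leaf 'value' wrappers.
    return strip_value_wrappers(ckpt)


# Both layouts reduce to the same {'decoder': ...} subtree.
assert common_subtree(linen_ckpt) == common_subtree(nnx_ckpt)
```

In this toy form, comparing params before and after loading reduces to comparing the two normalized subtrees.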

per_device_batch_size=1 run_name=vllm_decode_test \
use_chat_template=True prompt="Suggest some famous landmarks in London." \
per_device_batch_size=1 run_name=vllm_decode_test max_target_length=64 \
use_chat_template=False prompt="Suggest some famous landmarks in London." \
Collaborator

What is the issue if we use sft.yml?


is_nnx_checkpoint = True
if (
"params" in metadata.item_metadata.tree.keys()
Collaborator

I think the better way to decide if it is an NNX checkpoint is to check for nnx.Variable in the leaf.
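
As a rough illustration of the two detection strategies under discussion, the sketch below contrasts a top-level key check with a leaf-type check on toy trees. `Variable` is a hypothetical stand-in for nnx.Variable, and `has_params_key` and `has_variable_leaf` are illustrative names. The thread does not confirm that checkpoint metadata exposes nnx.Variable leaves, so treat the leaf check as an assumption.

```python
class Variable:
    """Hypothetical stand-in for nnx.Variable: a wrapper holding a value."""

    def __init__(self, value):
        self.value = value


def has_params_key(tree):
    """Key-based heuristic: the Linen layout nests under a top-level 'params'."""
    return "params" in tree


def has_variable_leaf(tree):
    """Leaf-based heuristic: True if any leaf is a Variable wrapper."""
    if isinstance(tree, Variable):
        return True
    if isinstance(tree, dict):
        return any(has_variable_leaf(v) for v in tree.values())
    return False


linen_like = {"params": {"params": {"decoder": {"kernel": 1.0}}}}
nnx_like = {"decoder": {"kernel": Variable(1.0)}}

print(has_params_key(linen_like), has_variable_leaf(linen_like))  # True False
print(has_params_key(nnx_like), has_variable_leaf(nnx_like))      # False True
```

The key check depends on how the checkpoint is nested, while the leaf check depends on what the leaves are. The two can disagree if a tree happens to contain a module named "params".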

target_for_restore = jax.tree.map(
lambda v: v.value,
sharded_state,
is_leaf=lambda n: isinstance(n, nnx.Variable),
Collaborator

If it is not an NNX checkpoint (i.e., is_nnx_checkpoint = False), why do we have the nnx.Variable check?
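
For context on what the quoted snippet does, here is a minimal pure-Python sketch of mapping over a tree while treating wrapper objects as leaves. `Variable` and `tree_map` are hypothetical stand-ins for nnx.Variable and jax.tree.map, written so the example runs without JAX or Flax. One possible reading, not confirmed in this thread, is that `sharded_state` always comes from the in-memory NNX model and so holds wrappers whichever checkpoint format is on disk.

```python
class Variable:
    """Hypothetical stand-in for nnx.Variable: a wrapper holding a value."""

    def __init__(self, value):
        self.value = value


def tree_map(fn, tree, is_leaf):
    """Simplified stand-in for jax.tree.map over nested dicts."""
    # Stop descending at anything is_leaf accepts, or at any non-dict node.
    if is_leaf(tree) or not isinstance(tree, dict):
        return fn(tree)
    return {k: tree_map(fn, v, is_leaf) for k, v in tree.items()}


# Toy in-memory state: every leaf is wrapped in a Variable.
sharded_state = {
    "decoder": {"kernel": Variable([1.0, 2.0]), "bias": Variable([0.0])}
}

# Unwrap each Variable to its raw value, mirroring the snippet above.
target_for_restore = tree_map(
    lambda v: v.value,
    sharded_state,
    is_leaf=lambda n: isinstance(n, Variable),
)

assert target_for_restore == {"decoder": {"kernel": [1.0, 2.0], "bias": [0.0]}}
```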
