Skip to content

added example for bidirectional checkpoint testing #1540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

wesleytruong
Copy link
Contributor

This pr adds

  • an example script for bidirectional testing of checkpoint conversion scripts
  • a checkpoint_conversion.md to describe our methodology.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 6, 2025
logger.info(f"Loading chkpt at: {checkpoint_path}")
load_from_hf = False
for filename in os.listdir(checkpoint_path):
if filename == "model.safetensors.index.json":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not reliable -- if there is only one .safetensors file, there won't be such index file

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename it to README.md (so it's displayed when entering this folder)

### Sanity Check (Greedy Decode)
A quick way to sanity check if your conversion is correct is to perform greedy decoding inference on both the initial and converted checkpoints and confirm that they are the same. This method doesn't guarantee correctness but will very likely result in a fast **true negative** if the model definitions are not the same. For greedy decoding, the `generation/test_generate.py` script can be used.

Note that the model definitions can be influenced by external factors than correctness of weight conversion. For example, using our verified `convert_to_hf.py` script then running greedy decoding using HF `transformers` without a correct `config.json` will result in a **false negative** since our weights are correct but the model definition is incorrect due to `config.json`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too obscure than it needs to be. Don't need say what could go wrong, say what they need to do to get right.
E.g. you can just say in order to use HF transformers model, one needs to feed a correct config.json. Remove the "false negative" part.

## Methods

### Sanity Check (Greedy Decode)
A quick way to sanity check if your conversion is correct is to perform greedy decoding inference on both the initial and converted checkpoints and confirm that they are the same. This method doesn't guarantee correctness but will very likely result in a fast **true negative** if the model definitions are not the same. For greedy decoding, the `generation/test_generate.py` script can be used.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test_generate is for llama 3 only. In general we don't have such thing, and shouldn't rely on them anyways. So if you still would like to have it here, let's explicitly say something like "it's only available for llama 3, but the methodology is general".

### Comprehensive Check (KL Divergence)
To ensure comprehensive end-to-end correctness we recommend using KL divergence loss to compare the logits between forward passes of both the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions. A result of zero or a very low KL divergence indicates that the model definitions are equivalent. This method is crucial as it evaluates the entire probability distribution, not just the highest probability at each step.

In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. We additionally compare the conversions done with no permutation to double check that our permutation results in a lower kl divergence loss.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to provide some context on why this permutation is needed in the first place. Otherwise people will get confused why you mention it at all.

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks quite good. Had some final comments.

state_dict.pop(k, None)

# Checkpoint Loading
logger.info(f"Loading chkpt at: {checkpoint_path}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.info(f"Loading chkpt at: {checkpoint_path}")
logger.info(f"Loading checkpoint at: {checkpoint_path}")

Comment on lines 111 to 118
hf_model_path = "outputs/checkpoint/step-0-tohf"
hf_model_path_no_perm = "outputs/checkpoint/step-0-tohfnoperm"

# tt params
config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
baseline_checkpoint_path = "outputs/checkpoint/step-0-fromllama"
checkpoint_path = "outputs/checkpoint/step-0-fromhf"
checkpoint_path_no_perm = "outputs/checkpoint/step-0-fromhfnoperm"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add comment: what are these checkpoints and how they are generated / downloaded?
The point is -- for any one working on a new model, they know what to do.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now reads very well!

### Comprehensive Check (KL Divergence)
To ensure comprehensive end-to-end correctness we recommend using KL divergence loss to compare the logits between forward passes of both the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions. A result of zero or a very low KL divergence indicates that the model definitions are equivalent. This method is crucial as it evaluates the entire probability distribution, not just the highest probability at each step.

In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a kl divergence test can reveal subtle inaccuracies such as this, we additionally compare the kl divergence between the original and converted model with and without the permutation. The results are as follows:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a kl divergence test can reveal subtle inaccuracies such as this, we additionally compare the kl divergence between the original and converted model with and without the permutation. The results are as follows:
In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's name it numerical_tests_example.py

if __name__ == "__main__":
# hf params
hf_model_name = "meta-llama/Meta-Llama-3-8B"
hf_model_path = "outputs/test_checkpoint/step-0-tohf" # safetensors checkpoint from convert_from_hf.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought convert_from_hf should give us DCP, convert_to_hf should give us safetensors. How come the comments are saying reverted?


# tt params
config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
baseline_checkpoint_path = "outputs/test_checkpoint/step-0-fromllama" # dcp checkpoint from convert_from_llama.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should simplify this script:

  1. we should delete convert_from_llama.py as its functionality has been fully covered with what we recently developed
  2. In general, for models we support, we wouldn't have two sources of non-titan checkpoint -- we'd only have one on HF. So the test logic should be
  • baseline: HF forward
  • test: convert_from_hf -> run torchtitan forward

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Had some final comments and a question for @fegin .

Please also add a pointer to this folder, in parallel to https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/README.md?plain=1#L67-L69


# Checkpoint Loading
logger.info(f"Loading checkpoint at: {checkpoint_path}")
dcp.load(state_dict, checkpoint_id=checkpoint_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait I thought you'd need
https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L437

maybe in this case it's not needed because the pointer to params didn't change?
cc @fegin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This model is not wrapped with ModelWrapper and the logic doesn't call distributed state_dict either, so technically, load_state_dict() is not required. dcp.load() will perform the inplace update.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wesleytruong wait I'm more confused.
In convert_from_hf.py you used ModelWrapper for loading HF and saving.
Here you call state_dict = model.state_dict() without ModelWrapper -- why a checkpoint from the wrapped state dict can be loaded into the non-wrapped state dict? Are they interchangeable??

Copy link
Contributor Author

@wesleytruong wesleytruong Aug 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wesleytruong wait I'm more confused. In convert_from_hf.py you used ModelWrapper for loading HF and saving. Here you call state_dict = model.state_dict() without ModelWrapper -- why a checkpoint from the wrapped state dict can be loaded into the non-wrapped state dict? Are they interchangeable??

Sorry for the confusion, as Chien-Chin said, they end up with the same result since it's not distributed. From what I understand under the hood ModelWrapper calls PTD's get_model_state_dict which handles getting the state dict of a sharded model, and dcp.load handles loading state dict from a sharded checkpoint. DCP can go from unsharded/sharded checkpoint to unsharded state dict, and both model.state_dict() and ModelWrapper's get_model_state_dict are full state dicts if model is unsharded, so that's why this works.

Either way, I should either change both to follow ModelWrapper or model for consistency. Which one would you prefer?
@tianyu-l

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's change to ModelWrapper for consistency, as user may extend this example and run it with compile / AC / distributed where the wrapper is necessary.

### Comprehensive Check (KL Divergence)
To ensure comprehensive end-to-end correctness we recommend using KL divergence loss to compare the logits between forward passes of both the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions. A result of zero or a very low KL divergence indicates that the model definitions are equivalent. This method is crucial as it evaluates the entire probability distribution, not just the highest probability at each step.

In our `./scripts/checkpoint_conversion/numerical_test_example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. This script only tests `from_hf` direction since loading a huggingface checkpoint requires correctly converting the instantiated `torchtitan` state dict `to_hf` so that safetensors weights can be loaded into it. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
In our `./scripts/checkpoint_conversion/numerical_test_example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. This script only tests `from_hf` direction since loading a huggingface checkpoint requires correctly converting the instantiated `torchtitan` state dict `to_hf` so that safetensors weights can be loaded into it. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows:
In our `./scripts/checkpoint_conversion/numerical_test_example.py` this will be performing forward on DCP checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in HuggingFace `AutoModelForCausalLM`. This script tests the HuggingFace -> `torchtitan` direction, as loading a HuggingFace checkpoint requires both
- converting the instantiated `torchtitan` state dict `to_hf` so that safetensors weights can be loaded into it, and
- converting the HF version of state dict back to torchtitan using `from_hf`.
To convert Llama 3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows:

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had a question.

Please also address

Please also add a pointer to this folder, in parallel to https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/README.md?plain=1#L67-L69


# Checkpoint Loading
logger.info(f"Loading checkpoint at: {checkpoint_path}")
dcp.load(state_dict, checkpoint_id=checkpoint_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wesleytruong wait I'm more confused.
In convert_from_hf.py you used ModelWrapper for loading HF and saving.
Here you call state_dict = model.state_dict() without ModelWrapper -- why a checkpoint from the wrapped state dict can be loaded into the non-wrapped state dict? Are they interchangeable??

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants