added example for bidirectional checkpoint testing #1540
base: main
Conversation
Force-pushed from e3e1be8 to f5f9f14.
logger.info(f"Loading chkpt at: {checkpoint_path}") | ||
load_from_hf = False | ||
for filename in os.listdir(checkpoint_path): | ||
if filename == "model.safetensors.index.json": |
this is not reliable -- if there is only one .safetensors file, there won't be such an index file
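A more robust alternative (a minimal sketch, not necessarily the PR's final logic) is to look for any `.safetensors` file instead of only the index file; `checkpoint_path` is the same directory variable used in the snippet above:

```python
import os

# Sketch: treat the directory as an HF safetensors checkpoint if it contains
# any *.safetensors shard; single-shard checkpoints ship without an index file.
load_from_hf = any(
    name.endswith(".safetensors") for name in os.listdir(checkpoint_path)
)
```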
rename it to README.md (so it's displayed when entering this folder)
### Sanity Check (Greedy Decode)
A quick way to sanity check if your conversion is correct is to perform greedy decoding inference on both the initial and converted checkpoints and confirm that they are the same. This method doesn't guarantee correctness but will very likely result in a fast **true negative** if the model definitions are not the same. For greedy decoding, the `generation/test_generate.py` script can be used.

Note that the model definitions can be influenced by external factors than correctness of weight conversion. For example, using our verified `convert_to_hf.py` script then running greedy decoding using HF `transformers` without a correct `config.json` will result in a **false negative** since our weights are correct but the model definition is incorrect due to `config.json`.
This is more obscure than it needs to be. Don't say what could go wrong; say what they need to do to get it right. E.g. you can just say that in order to use the HF `transformers` model, one needs to feed a correct `config.json`. Remove the "false negative" part.
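To illustrate the suggested wording with a concrete (hypothetical) flow: copy a correct `config.json` next to the converted safetensors weights, then load them with HF `transformers`. The repo id and output directory below are taken from elsewhere in this PR; the exact steps are an assumption, not the PR's script:

```python
import shutil

from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM

# Hypothetical output directory of convert_to_hf.py.
converted_dir = "outputs/checkpoint/step-0-tohf"

# Fetch the reference config.json from the original HF repo and place it next to
# the converted weights so transformers builds the correct model definition.
config_file = hf_hub_download("meta-llama/Meta-Llama-3-8B", "config.json")
shutil.copy(config_file, f"{converted_dir}/config.json")

# The converted checkpoint can now be loaded for the greedy-decode comparison.
model = AutoModelForCausalLM.from_pretrained(converted_dir)
```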
## Methods

### Sanity Check (Greedy Decode)
A quick way to sanity check if your conversion is correct is to perform greedy decoding inference on both the initial and converted checkpoints and confirm that they are the same. This method doesn't guarantee correctness but will very likely result in a fast **true negative** if the model definitions are not the same. For greedy decoding, the `generation/test_generate.py` script can be used.
This `test_generate` is for Llama 3 only. In general we don't have such a thing, and shouldn't rely on them anyway. So if you still would like to have it here, let's explicitly say something like "it's only available for Llama 3, but the methodology is general".
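Since `test_generate.py` is Llama 3 specific, here is a rough sketch of the general methodology using plain HF `transformers` greedy decoding; the model name, converted-checkpoint path, and prompt are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def greedy_continuation(model_dir: str, prompt: str, max_new_tokens: int = 32) -> str:
    """Greedy-decode a short continuation from a (possibly converted) checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output[0], skip_special_tokens=True)

prompt = "The capital of France is"
original = greedy_continuation("meta-llama/Meta-Llama-3-8B", prompt)
converted = greedy_continuation("outputs/checkpoint/step-0-tohf", prompt)  # hypothetical path
assert original == converted, "greedy decodes diverge -- conversion is likely wrong"
```

Identical outputs don't prove correctness, but a mismatch is a quick signal that something in the conversion is off.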
### Comprehensive Check (KL Divergence)
To ensure comprehensive end-to-end correctness we recommend using KL divergence loss to compare the logits between forward passes of both the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions. A result of zero or a very low KL divergence indicates that the model definitions are equivalent. This method is crucial as it evaluates the entire probability distribution, not just the highest probability at each step.

In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. We additionally compare the conversions done with no permutation to double check that our permutation results in a lower kl divergence loss.
You need to provide some context on why this permutation is needed in the first place. Otherwise people will get confused why you mention it at all.
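For reference, the comparison described above can be sketched as follows; this is not the PR's `example.py`, and the commented usage lines assume two models that accept the same token batch:

```python
import torch
import torch.nn.functional as F

def kl_between_logits(logits_ref: torch.Tensor, logits_test: torch.Tensor) -> torch.Tensor:
    """Mean KL(ref || test) over all token positions, computed from raw logits."""
    log_p = F.log_softmax(logits_ref.float(), dim=-1).flatten(0, -2)   # reference
    log_q = F.log_softmax(logits_test.float(), dim=-1).flatten(0, -2)  # converted model
    # F.kl_div(input, target, log_target=True) computes sum target_prob * (target - input),
    # i.e. KL(target || input), averaged over rows with reduction="batchmean".
    return F.kl_div(log_q, log_p, log_target=True, reduction="batchmean")

# Hypothetical usage with the same input_ids fed to both models:
# logits_hf = hf_model(input_ids).logits
# logits_tt = titan_model(input_ids)
# print(kl_between_logits(logits_hf, logits_tt))  # ~0 if the conversion is faithful
```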
Looks quite good. Had some final comments.
state_dict.pop(k, None)

# Checkpoint Loading
logger.info(f"Loading chkpt at: {checkpoint_path}")
logger.info(f"Loading chkpt at: {checkpoint_path}") | |
logger.info(f"Loading checkpoint at: {checkpoint_path}") |
hf_model_path = "outputs/checkpoint/step-0-tohf" | ||
hf_model_path_no_perm = "outputs/checkpoint/step-0-tohfnoperm" | ||
|
||
# tt params | ||
config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml" | ||
baseline_checkpoint_path = "outputs/checkpoint/step-0-fromllama" | ||
checkpoint_path = "outputs/checkpoint/step-0-fromhf" | ||
checkpoint_path_no_perm = "outputs/checkpoint/step-0-fromhfnoperm" |
Could you add comments: what are these checkpoints and how are they generated / downloaded?
The point is -- for anyone working on a new model, they know what to do.
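As an illustration of the kind of comments being requested, the parameter block might be annotated roughly like this; the script names are the ones discussed in this PR, while the exact wording and the mapping of scripts to formats (safetensors from `convert_to_hf.py`, DCP from `convert_from_hf.py`) is an assumption based on the thread below:

```python
# hf params
hf_model_name = "meta-llama/Meta-Llama-3-8B"  # original HF repo, downloaded via transformers
# safetensors checkpoints produced by convert_to_hf.py, with and without the
# RoPE-related attention permutation applied:
hf_model_path = "outputs/checkpoint/step-0-tohf"
hf_model_path_no_perm = "outputs/checkpoint/step-0-tohfnoperm"

# tt params
config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
# DCP checkpoint converted from the original Meta weights by convert_from_llama.py:
baseline_checkpoint_path = "outputs/checkpoint/step-0-fromllama"
# DCP checkpoints produced by convert_from_hf.py, with and without the permutation:
checkpoint_path = "outputs/checkpoint/step-0-fromhf"
checkpoint_path_no_perm = "outputs/checkpoint/step-0-fromhfnoperm"
```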
now reads very well!
### Comprehensive Check (KL Divergence)
To ensure comprehensive end-to-end correctness we recommend using KL divergence loss to compare the logits between forward passes of both the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions. A result of zero or a very low KL divergence indicates that the model definitions are equivalent. This method is crucial as it evaluates the entire probability distribution, not just the highest probability at each step.

In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a kl divergence test can reveal subtle inaccuracies such as this, we additionally compare the kl divergence between the original and converted model with and without the permutation. The results are as follows:
Suggested change:
In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows:
let's name it numerical_tests_example.py
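For context on the permutation mentioned above: HF's Llama implementation stores the RoPE dimensions of each head as two contiguous halves, while the native implementation interleaves them pairwise, so the rows of the Q/K projection weights must be reshuffled during conversion. A rough sketch of that reshuffle (shape-wise it mirrors HF's own Llama conversion script; not necessarily the exact code in this PR):

```python
import torch

def permute_rope(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Reorder Q/K projection rows between the pairwise-interleaved RoPE layout
    and the split-halves layout used by HF's Llama implementation."""
    out_dim, in_dim = w.shape
    head_dim = out_dim // n_heads
    return (
        w.view(n_heads, head_dim // 2, 2, in_dim)
        .transpose(1, 2)
        .reshape(out_dim, in_dim)
    )
```

Skipping this step still yields a model that loads and runs, but RoPE rotates the wrong dimension pairs, which is exactly the kind of subtle drift the KL divergence comparison surfaces.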
if __name__ == "__main__":
    # hf params
    hf_model_name = "meta-llama/Meta-Llama-3-8B"
    hf_model_path = "outputs/test_checkpoint/step-0-tohf"  # safetensors checkpoint from convert_from_hf.py
I thought `convert_from_hf` should give us DCP, and `convert_to_hf` should give us safetensors. How come the comments say the reverse?

    # tt params
    config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
    baseline_checkpoint_path = "outputs/test_checkpoint/step-0-fromllama"  # dcp checkpoint from convert_from_llama.py
I think we should simplify this script:
- we should delete `convert_from_llama.py` as its functionality has been fully covered with what we recently developed
- In general, for models we support, we wouldn't have two sources of non-titan checkpoint -- we'd only have one on HF. So the test logic should be
  - baseline: HF forward
  - test: convert_from_hf -> run torchtitan forward
Force-pushed from cd34d7c to 833785b.
Looks good. Had some final comments and a question for @fegin .
Please also add a pointer to this folder, in parallel to https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/README.md?plain=1#L67-L69

# Checkpoint Loading
logger.info(f"Loading checkpoint at: {checkpoint_path}")
dcp.load(state_dict, checkpoint_id=checkpoint_path)
wait I thought you'd need https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L437
maybe in this case it's not needed because the pointer to params didn't change?
cc @fegin
This model is not wrapped with ModelWrapper and the logic doesn't call distributed state_dict either, so technically, `load_state_dict()` is not required. `dcp.load()` will perform the in-place update.
@wesleytruong wait I'm more confused.
In `convert_from_hf.py` you used `ModelWrapper` for loading HF and saving. Here you call `state_dict = model.state_dict()` without `ModelWrapper` -- why can a checkpoint from the wrapped state dict be loaded into the non-wrapped state dict? Are they interchangeable??
Sorry for the confusion, as Chien-Chin said, they end up with the same result since it's not distributed. From what I understand, under the hood ModelWrapper calls PTD's `get_model_state_dict`, which handles getting the state dict of a sharded model, and `dcp.load` handles loading the state dict from a sharded checkpoint. DCP can go from an unsharded/sharded checkpoint to an unsharded state dict, and both `model.state_dict()` and ModelWrapper's `get_model_state_dict` return full state dicts if the model is unsharded, so that's why this works.
Either way, I should change both to use either ModelWrapper or model.state_dict() for consistency. Which one would you prefer?
@tianyu-l
Let's change to `ModelWrapper` for consistency, as users may extend this example and run it with compile / AC / distributed, where the wrapper is necessary.
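For context, a rough sketch of the wrapped flow being suggested, built on the public `torch.distributed.checkpoint` APIs that `ModelWrapper` is described above as using under the hood; the model and checkpoint path are placeholders, not the PR's actual script:

```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict

# Placeholder model; in the real script this is the torchtitan Transformer.
model = torch.nn.Linear(8, 8)

# For an unsharded, single-process model this returns the same full state dict
# as model.state_dict(), which is why the two approaches were interchangeable here.
state_dict = get_model_state_dict(model)

# dcp.load updates the tensors in state_dict in place from the DCP checkpoint,
# so an explicit load_state_dict() call isn't strictly required in this setting.
dcp.load(state_dict, checkpoint_id="outputs/checkpoint/step-0-fromhf")  # hypothetical path

# Push the loaded values back onto the model (the distributed-aware counterpart
# of model.load_state_dict).
set_model_state_dict(model, state_dict)
```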
### Comprehensive Check (KL Divergence)
To ensure comprehensive end-to-end correctness we recommend using KL divergence loss to compare the logits between forward passes of both the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions. A result of zero or a very low KL divergence indicates that the model definitions are equivalent. This method is crucial as it evaluates the entire probability distribution, not just the highest probability at each step.

In our `./scripts/checkpoint_conversion/numerical_test_example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. This script only tests `from_hf` direction since loading a huggingface checkpoint requires correctly converting the instantiated `torchtitan` state dict `to_hf` so that safetensors weights can be loaded into it. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows:
Suggested change:
In our `./scripts/checkpoint_conversion/numerical_test_example.py` this will be performing forward on DCP checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in HuggingFace `AutoModelForCausalLM`. This script tests the HuggingFace -> `torchtitan` direction, as loading a HuggingFace checkpoint requires both
- converting the instantiated `torchtitan` state dict `to_hf` so that safetensors weights can be loaded into it, and
- converting the HF version of state dict back to torchtitan using `from_hf`.
To convert Llama 3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows:
Had a question.
Please also address
Please also add a pointer to this folder, in parallel to https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/README.md?plain=1#L67-L69
This PR adds `checkpoint_conversion.md` to describe our methodology.