Commit e3e1be8

added example for bidirectional checkpoint testing
1 parent 36ec547 commit e3e1be8

File tree

2 files changed, +205 -0 lines changed
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
# Testing Checkpoint Conversion for Correctness

When converting checkpoints between file formats or model definitions, we need to ensure that the converted checkpoints are correct, i.e. that the model definition is preserved: the converted checkpoint's weights must produce the same outputs when loaded in the new, intended program context.

This guide provides a general framework for testing your conversion script for correctness. The example used here is bidirectional conversion between HuggingFace and `torchtitan`.

## Methods

### Sanity Check (Greedy Decode)

A quick way to sanity-check your conversion is to run greedy decoding inference on both the original and converted checkpoints and confirm that the outputs match. This method doesn't guarantee correctness, but it will very likely produce a fast **true negative** if the model definitions are not the same. For greedy decoding, the `generation/test_generate.py` script can be used.

Note that the model definition can be influenced by factors other than the correctness of the weight conversion. For example, using our verified `convert_to_hf.py` script and then running greedy decoding with HF `transformers` without a correct `config.json` will result in a **false negative**: our weights are correct, but the model definition is wrong because of `config.json`.
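
As an illustrative sketch (not one of the scripts in this repo), the snippet below compares greedy continuations directly with HF `transformers`, assuming the converted checkpoint directory is loadable by `from_pretrained` (i.e. has a valid `config.json`); the prompt, decode length, and paths are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

original = "meta-llama/Meta-Llama-3-8B"       # reference model
converted = "outputs/checkpoint/step-0-tohf"  # converted checkpoint (placeholder path)

tokenizer = AutoTokenizer.from_pretrained(original)
prompt_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids


def greedy_continuation(path):
    model = AutoModelForCausalLM.from_pretrained(path)
    with torch.no_grad():
        # do_sample=False -> greedy decoding: the argmax token is taken at every step
        output_ids = model.generate(prompt_ids, max_new_tokens=32, do_sample=False)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


# Matching continuations are necessary, but not sufficient, evidence of a correct conversion.
print(greedy_continuation(original))
print(greedy_continuation(converted))
```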

### Comprehensive Check (KL Divergence)

To ensure comprehensive end-to-end correctness, we recommend using a KL divergence loss to compare the logits from forward passes of the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions: a result of zero or a very low KL divergence indicates that the model definitions are equivalent. This check is crucial because it evaluates the entire probability distribution, not just the highest-probability token at each step.
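
For reference, the quantity being measured between two next-token distributions $P$ and $Q$ (the softmax of each model's logits over the vocabulary $V$) is the standard KL divergence (notation ours):

$$
D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{v \in V} P(v)\,\log \frac{P(v)}{Q(v)}
$$

In the example script this is computed with `F.kl_div`, which takes log-probabilities as the input and probabilities as the target.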

In `./scripts/checkpoint_conversion/example.py`, this means running forward passes on DCP checkpoints loaded in `torchtitan` and on safetensors checkpoints loaded in HuggingFace `AutoModelForCausalLM`. We additionally compare conversions done without the permutation, to double-check that applying our permutation results in a lower KL divergence loss.

```
$ python ./scripts/checkpoint_conversion/example.py
Average loss for test from_hf is -4.951488641303202e-14
Average loss for test to_hf is -4.951488641303202e-14
Average loss for test from_hf_no_perm is 6.310602202574955e-06
Average loss for test to_hf_no_perm is 2.0396773834363557e-05
```
Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
import json
import os
import sys
from pathlib import Path
from typing import Optional

import torch

import torch.distributed.checkpoint as dcp
import torch.nn.functional as F
from torch.distributed.checkpoint import HuggingFaceStorageReader
from torchtitan.components.checkpoint import excluded_parameters_for_model_only
from torchtitan.config import ConfigManager
from torchtitan.protocols.train_spec import get_train_spec
from torchtitan.tools.logging import logger
from transformers import AutoModelForCausalLM

device_type = "cuda" if torch.cuda.is_available() else "cpu"


def loss_fn(logits1, logits2):
    # Convert logits to probabilities: kl_div expects log-probabilities as the
    # input and probabilities as the target
    probs1 = F.log_softmax(logits1, dim=-1)
    probs2 = F.softmax(logits2, dim=-1)

    # Calculate KL Divergence
    kl_loss = F.kl_div(probs1, probs2, reduction="mean")
    return kl_loss


@torch.no_grad
def forward_hf(model_name, model_path: Optional[str], input_ids):
    # Load the model from model_path if given, otherwise from model_name
    model_path = model_path if model_path else model_name
    model = AutoModelForCausalLM.from_pretrained(model_path)

    device = torch.device(device_type)
    model.to(device)

    # List to store outputs
    outputs_list = []

    for inputs in input_ids:
        inputs = inputs.to(device)
        # Greedy-decode a single new token; prompt_len is defined in __main__
        outputs = model.generate(
            inputs=inputs,
            max_length=prompt_len + 1,
            do_sample=False,
            output_logits=True,
            return_dict_in_generate=True,
        )

        # Stack the per-step logits into shape (num_new_tokens, batch, vocab)
        outputs = torch.stack(outputs.logits)
        outputs_list.append(outputs)

    del model
    torch.cuda.empty_cache()

    return outputs_list


@torch.no_grad
def forward_tt(config_path, checkpoint_path, test_set):

    config_manager = ConfigManager()
    config = config_manager.parse_args([f"--job.config_file={config_path}"])

    train_spec = get_train_spec(config.model.name)

    model_args = train_spec.model_args[config.model.flavor]
    model_args.update_from_config(config)

    model = train_spec.model_cls(model_args)

    # materialize model
    device = torch.device(device_type)
    model.to_empty(device=device)
    with torch.no_grad():
        model.init_weights()
    model.eval()

    state_dict = model.state_dict()
    for k in excluded_parameters_for_model_only:
        state_dict.pop(k, None)

    # Checkpoint Loading
    logger.info(f"Loading checkpoint at: {checkpoint_path}")
    load_from_hf = False
    for filename in os.listdir(checkpoint_path):
        if filename == "model.safetensors.index.json":
            load_from_hf = True
    if load_from_hf:
        # Convert keys to the HF layout, load the safetensors checkpoint, then
        # convert back. dcp.load fills hf_state_dict rather than the model's own
        # tensors, so the converted weights are copied into the model explicitly
        # (strict=False because the keys popped above are not present).
        sd_adapter = train_spec.state_dict_adapter
        hf_state_dict = sd_adapter.to_hf(state_dict)
        dcp.load(hf_state_dict, storage_reader=HuggingFaceStorageReader(path=checkpoint_path))
        state_dict = sd_adapter.from_hf(hf_state_dict)
        model.load_state_dict(state_dict, strict=False)
    else:
        # dcp.load writes in place into state_dict, whose tensors alias the model's parameters
        dcp.load(state_dict, checkpoint_id=checkpoint_path)

    output_list = []
    for prompt in test_set:
        input_ids = prompt.to(device_type)
        # ensure batch dimension (T,) --> (B, T)
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        # obtains the logits of only the last token in the predictions
        predictions = model(input_ids)[:, -1, :].unsqueeze(1)
        output_list.append(predictions)

    del model
    torch.cuda.empty_cache()

    return output_list


if __name__ == "__main__":
    # hf params
    hf_model_name = "meta-llama/Meta-Llama-3-8B"
    hf_model_path = "outputs/checkpoint/step-0-tohf"
    hf_model_path_no_perm = "outputs/checkpoint/step-0-tohfnoperm"

    # tt params
    config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
    baseline_checkpoint_path = "outputs/checkpoint/step-0-fromllama"
    checkpoint_path = "outputs/checkpoint/step-0-fromhf"
    checkpoint_path_no_perm = "outputs/checkpoint/step-0-fromhfnoperm"

    # test params
    prompt_len = 8
    test_size = 100

    config_manager = ConfigManager()
    config = config_manager.parse_args([f"--job.config_file={config_path}"])
    train_spec = get_train_spec(config.model.name)
    tokenizer = train_spec.build_tokenizer_fn(config)

    # Build test set of randomly generated token ids
    test_set = [
        torch.randint(
            0,
            tokenizer.get_vocab_size(),
            (
                1,  # batch size
                prompt_len,
            ),
        )
        for _ in range(test_size)
    ]

    # baseline logits
    baseline_hf_outputs = forward_hf(hf_model_name, None, test_set)
    baseline_tt_outputs = forward_tt(config_path, baseline_checkpoint_path, test_set)

    # testing from hf script
    from_hf_outputs = forward_tt(config_path, checkpoint_path, test_set)
    from_hf_outputs_no_perm = forward_tt(config_path, checkpoint_path_no_perm, test_set)

    # testing to hf script
    to_hf_outputs = forward_hf(hf_model_name, hf_model_path, test_set)
    to_hf_outputs_no_perm = forward_hf(hf_model_name, hf_model_path_no_perm, test_set)

    # Define the set of outputs to test loss for; each entry is [hf logits, tt logits]
    test_configs = {
        "from_hf": [baseline_hf_outputs, from_hf_outputs],
        "to_hf": [to_hf_outputs, baseline_tt_outputs],
        "from_hf_no_perm": [baseline_hf_outputs, from_hf_outputs_no_perm],
        "to_hf_no_perm": [to_hf_outputs_no_perm, baseline_tt_outputs],
    }
    avg_losses = {}

    for test_name, (hf_outputs, tt_outputs) in test_configs.items():
        total_loss = 0
        for hf_logits, tt_logits in zip(hf_outputs, tt_outputs):
            total_loss += loss_fn(hf_logits, tt_logits)
        avg_loss = total_loss / len(test_set)
        avg_losses[test_name] = avg_loss.item()

    for test_name, avg_loss in avg_losses.items():
        print(f"Average loss for test {test_name} is {avg_loss}")
