
Finetune from pre-trained models #1300


Open: wants to merge 5 commits into main

Conversation

@vwxyzjn commented Jun 15, 2025

This PR adds two main changes:

  1. Add a max_seq_len option so the model can load a pre-trained Llama 3.1 8B checkpoint. Note that I had to revert to the old checkpointing code; otherwise I got the weird error trace shown at the bottom of this PR description.
  2. Allow starting from a checkpoint without enable_checkpoint. Use case: a user might want to fine-tune without saving intermediate checkpoints.

Tested with the following commands:

# Download the tokenizer and model weights
rm -rf tmp
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/tokenizer.model --local-dir tmp
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/consolidated.00.pth --local-dir tmp
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/params.json --local-dir tmp
# Convert the model weights to the DCP format and move it and the tokenizer to the assets folder
mkdir -p assets/tokenizer && cp tmp/original/tokenizer.model assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model
uv run python -m scripts.convert_llama_to_dcp tmp/original/ assets/models/dcp/llama3.1-8B
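
As a sanity check (an illustrative sketch, not one of the PR's tested commands), the converted DCP checkpoint can be inspected by reading its metadata via the public torch.distributed.checkpoint API, assuming the output path used above:

```python
# Minimal sketch: list what the converted DCP checkpoint contains by reading
# its metadata. Assumes the output directory used in the commands above.
from torch.distributed.checkpoint import FileSystemReader

reader = FileSystemReader("assets/models/dcp/llama3.1-8B")
metadata = reader.read_metadata()

for name, item in metadata.state_dict_metadata.items():
    # Tensor entries carry a .size; non-tensor (bytes) entries may not.
    print(name, getattr(item, "size", None))
```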

Then you can fine-tune from the checkpoint:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" uv run ./run_train.sh \
  --model.tokenizer_path assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model \
  --training.max_seq_len 131072 \
  --checkpoint.initial_load_path "assets/models/dcp/llama3.1-8B" \
  --profiling.no_enable_profiling \
  --activation_checkpoint.mode full \
  --training.global_batch_size 64 \
  --lr_scheduler.warmup_steps 40 \
  --optimizer.lr 1e-5

Error trace with the new load checkpoint code

If I don't revert to the old checkpointing code, I get:

    File "/home/ubuntu/code/thirdparty/torchtitan/.venv/lib/python3.13/site-packages/torch/19:41:24 [255/770]
oint/utils.py", line 465, in inner_func
      return func(*args, **kwargs)
    File "/home/ubuntu/code/thirdparty/torchtitan/.venv/lib/python3.13/site-packages/torch/distributed/checkp
oint/state_dict_loader.py", line 177, in load
      _load_state_dict(
      ~~~~~~~~~~~~~~~~^
          state_dict=statetful_sd,
          ^^^^^^^^^^^^^^^^^^^^^^^^
      ...<3 lines>...
          planner=planner,
          ^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/ubuntu/code/thirdparty/torchtitan/.venv/lib/python3.13/site-packages/torch/distributed/checkp
oint/state_dict_loader.py", line 234, in _load_state_dict
      central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
                               ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ubuntu/code/thirdparty/torchtitan/.venv/lib/python3.13/site-packages/torch/distributed/checkp
oint/utils.py", line 196, in reduce_scatter
      all_data = self.gather_object(local_data)
    File "/home/ubuntu/code/thirdparty/torchtitan/.venv/lib/python3.13/site-packages/torch/distributed/checkp
oint/utils.py", line 135, in gather_object
      dist.gather_object(
      ~~~~~~~~~~~~~~~~~~^
          obj=object,
          ^^^^^^^^^^^
      ...<2 lines>...
          group=self.group,
          ^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/ubuntu/code/thirdparty/torchtitan/.venv/lib/python3.13/site-packages/torch/distributed/c10d_l
ogger.py", line 81, in wrapper
      return func(*args, **kwargs)
    File "/home/ubuntu/code/thirdparty/torchtitan/.venv/lib/python3.13/site-packages/torch/distributed/distri
buted_c10d.py", line 3139, in gather_object
      input_tensor, local_size = _object_to_tensor(obj, current_device, group)
                                 ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ubuntu/code/thirdparty/torchtitan/.venv/lib/python3.13/site-packages/torch/distributed/distri
buted_c10d.py", line 2935, in _object_to_tensor
      _pickler(f).dump(obj)
      ~~~~~~~~~~~~~~~~^^^^^
  TypeError: cannot pickle code objects
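
For context, the final TypeError is pickle's own refusal to serialize code objects; the same failure can be reproduced in isolation (an illustrative sketch, unrelated to what exactly ends up in the DCP load plan):

```python
# Standalone illustration of the final error above: pickle refuses code objects,
# so a code object hidden anywhere in the object passed to dist.gather_object()
# bottoms out in _object_to_tensor with this same TypeError.
import pickle

code_obj = compile("x + 1", "<example>", "eval")  # any code object will do
try:
    pickle.dumps(code_obj)
except TypeError as exc:
    print(exc)  # e.g. "cannot pickle code objects"
```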

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jun 15, 2025
@tianyu-l (Contributor) left a comment

Having instructions or a usage example for the script convert_llama_to_dcp could be helpful.

allow for starting from a checkpoint without enable_checkpoint. Use case: the user might want to do fine-tuning without saving intermediate checkpoints.

We can think more about the UI, e.g. separate enable_load from enable_save. However, in your case, can't you just specify the interval to be a very large number?

I disabled it partially because it takes 6 mins to save an 8B model w/ non-async mode.

We will need to root cause and solve the issue.

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" uv run ./run_train.sh \
--model.tokenizer_path assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model \
--training.max_seq_len 131072 \
```
Contributor

I wonder if it's necessary to create this config -- how is it different from specifying --training.seq_len 131072?

Author

One example use case is when I don't actually have documents up to seq_len 131072, but the pre-trained model has a default seq_len of 131072.

Contributor

If I understand correctly, the seq_len field is only used when generating freqs_cis https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model/model.py#L397
This should be input-agnostic, so I feel you can just specify --training.seq_len to be however long you need (as long as it doesn't exceed the model's capability).
Let me know if it's not the case.

Author

Ah, I see the problem. The issue is that HuggingFaceDataset uses --training.seq_len, so the packed dataset also has the same length.

In that case, we should probably reuse the same seq_len but allow the HuggingFaceDataset to use a separate packed_len. WDYT?
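
For illustration, here is a generic packing sketch (made-up names, not torchtitan's actual HuggingFaceDataset): every packed sample comes out exactly seq_len tokens long, which is why the dataset length and the model's freqs_cis length are coupled today.

```python
# Generic illustration of sequence packing: documents are tokenized,
# concatenated, and chopped into fixed-length rows of seq_len tokens, so
# whatever value --training.seq_len takes becomes the length of every
# packed training sample the model sees.
def pack_tokens(token_streams, seq_len):
    buffer = []
    for tokens in token_streams:          # tokens: list[int] per document
        buffer.extend(tokens)
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]        # one packed sample of exactly seq_len
            buffer = buffer[seq_len:]

# e.g. packing short documents into rows of 8 tokens
samples = list(pack_tokens([[1, 2, 3], [4, 5, 6, 7, 8, 9], [10] * 10], seq_len=8))
assert all(len(s) == 8 for s in samples)
```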

Contributor

My question is why you'd want them to be different.
The requirement is: if the HF dataset uses seq_len_hf, then we need seq_len_transformer >= seq_len_hf to make sure freqs_cis is initialized with enough length.
But we don't need seq_len_transformer > seq_len_hf (or do we?), so it can just be seq_len_transformer = seq_len_hf = training.seq_len.

Author

I am worried that setting seq_len_transformer=131072 would make training OOM vs. seq_len_transformer=8192.

However, it appears I need to set seq_len_transformer=131072 if I am trying to load a pretrained model such as Llama 3.1 8B. Is this correct?

Contributor

it appears I need to set seq_len_transformer=131072 if I am trying to load a pretrained model such as llama 3.1 8B. Is this correct?

oh I see your worry.

I don't think it should be the case. Like I said, the only place seq_len matters in the transformer is for freqs_cis, which is a non-persistent buffer and shouldn't be included in the model checkpoint.
(Previously in torchtitan it could be, but after https://github.com/pytorch/torchtitan/pull/1236/files#diff-27a108fa6d4885d9c66306785cb36029c0b4f5a1542e63ae24e84eb7e9a273d1R87 it shouldn't.)

For your finetuning job, the model capability shouldn't be affected by specifying a smaller max_seq_len.

BTW, you could consider using CP in torchtitan for long-sequence finetuning.
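
To illustrate the non-persistent-buffer point in plain PyTorch (a minimal sketch with illustrative names, not torchtitan's actual module): a buffer registered with persistent=False never appears in state_dict(), so its precomputed length cannot conflict with the loaded checkpoint.

```python
# Minimal sketch of why a non-persistent buffer doesn't constrain checkpoint
# loading: it is excluded from state_dict(), so its length can differ between
# the pretraining and finetuning configs. Names here are illustrative only.
import torch
import torch.nn as nn

class WithFreqs(nn.Module):
    def __init__(self, seq_len: int):
        super().__init__()
        self.register_buffer("freqs_cis", torch.zeros(seq_len, 64), persistent=False)

m_pretrain = WithFreqs(seq_len=131072)
m_finetune = WithFreqs(seq_len=8192)

assert "freqs_cis" not in m_pretrain.state_dict()      # not saved in checkpoints
m_finetune.load_state_dict(m_pretrain.state_dict())    # loads fine despite different lengths
```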

Author

@vwxyzjn Jun 17, 2025

For your finetuning job, the model capability shouldn't be affected by specifying a smaller max_seq_len.

I guess an important question is this: if we have a pretrained model with seq_len=131072, should we always compute freqs_cis using seq_len=131072?

If the answer is yes, it would make sense to set up an arg called seq_len_transformer (in place of my current max_seq_len) and set it to 131072 when loading Llama 3.1 8B.

I see. It looks like, because of how freqs_cis is used in reshape_for_broadcast, it's fine if we calculate it without the full 131072. Then it doesn't make sense to save / load it.

Thanks. I will adjust the PR accordingly.
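
For reference, a sketch of the Llama-style RoPE precomputation being discussed (paraphrased, not torchtitan's exact model.py): freqs_cis is built up to a chosen maximum length and then sliced to each batch's actual sequence length, so precomputing fewer rows only caps the batch length, not the loaded weights.

```python
# Sketch of Llama-style RoPE precomputation (paraphrased, not torchtitan's exact code):
# freqs_cis has shape [max_len, head_dim // 2]; at runtime only the first seqlen
# rows are used, so the precomputed length just caps the batch length.
import torch

def precompute_freqs_cis(head_dim: int, max_len: int, theta: float = 500000.0) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_len, dtype=torch.float32)
    return torch.polar(torch.ones(max_len, head_dim // 2), torch.outer(t, freqs))

freqs_cis = precompute_freqs_cis(head_dim=128, max_len=8192)
seqlen = 4096
freqs_for_batch = freqs_cis[:seqlen]    # what reshape_for_broadcast-style code consumes
print(freqs_for_batch.shape)            # torch.Size([4096, 64])
```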


export HF_TOKEN=... # get your HF token from https://huggingface.co/settings/tokens
# Download the tokenizer and model weights
rm -rf tmp
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/tokenizer.model --local-dir tmp
Contributor

We covered downloading the tokenizer above, in the section "Downloading a tokenizer".

Author

Yeah, true. I was going to ask: do you want to replace that with huggingface-cli commands? We could use them for downloading both the tokenizer and the actual models.

Contributor

oh I see, maybe let's first put the complete huggingface-cli flow inside finetune.md. If people get used to it, we can change the version in the main README later.

Author

Sounds good!

@@ -114,6 +114,36 @@ Llama 3 8B model locally on 8 GPUs
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
```

### Fine-tuning from an existing checkpoint
Contributor

Can we put this under docs/finetune.md instead of main README? We can create a link to the doc around here.

Author

Of course. Will do.

Comment on lines +126 to +130
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/consolidated.00.pth --local-dir tmp
uv run huggingface-cli download meta-llama/Llama-3.1-8B original/params.json --local-dir tmp
# Convert the model weights to the DCP format and move it and the tokenizer to the assets folder
mkdir -p assets/tokenizer && cp tmp/original/tokenizer.model assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model
uv run python -m scripts.convert_llama_to_dcp tmp/original/ assets/models/dcp/llama3.1-8B
Contributor

Using uv is fine, but as a general instruction we shouldn't assume users have to use uv.

Contributor

Instead of tmp and assets/models/dcp, which look arbitrarily chosen, let's try to use generic placeholders.

Author

Ah, I forgot to remove the uv part. What do you mean by generic placeholders?

Contributor

Like, instead of tmp, use [original_model_dir], [dcp_model_dir], [tokenizer_dir], so that people know what to replace.

Contributor

oh we shouldn't just revert the changes -- instead we should investigate the root cause

cc @fegin please take a look at whether recent changes break anything

@lkhphuc (Contributor) commented Jun 16, 2025

    File "/home/ubuntu/code/thirdparty/torchtitan/.venv/lib/python3.13/site-packages/torch/distributed/distri
buted_c10d.py", line 2935, in _object_to_tensor
      _pickler(f).dump(obj)
      ~~~~~~~~~~~~~~~~^^^^^
  TypeError: cannot pickle code objects

I find that if you use Python 3.13, any error in loading the checkpoint, like missing keys or a shape mismatch, always results in this error.
If you use Python 3.12 or earlier, it throws the actual error explaining why it fails.

In addition, you cannot load a checkpoint created in a venv with Python 3.13 in a venv with Python 3.12. It resulted in some internal Python error, _Pathlib.___ something missing.

@fegin (Contributor) commented Jun 16, 2025

For checkpointing, I don't think we should separate enable_save from enable_load. It is too fine-grained and I don't think there is a real production use case; you will still need to save the final checkpoint anyway.

The first-step checkpoint is always enabled to allow users to quickly understand if there are errors. In the past, several users complained that they had to train X steps before they discovered there were checkpointing issues. I think we can disable this feature or make it configurable.

Async checkpointing being slow is a bug we need to figure out.

@vwxyzjn (Author) commented Jun 16, 2025

The first step checkpoint is always enabled to allow users to quick understand if there are errors.

This is such a great point! My first reaction was "why am I saving on step 1?"

I wonder if we should log something like "saving a first checkpoint to ensure it works".

@fegin (Contributor) commented Jun 17, 2025

You can consider using #1310 to avoid checkpoint overhead.
