Long sequence samples exceeding max_seq_len are resulting in CUDA OOM while performing LoRA SFT #471

Open
Akash-Nayak opened this issue Feb 17, 2025 · 0 comments

Describe the bug

While running LoRA SFT, training failed with a CUDA out-of-memory error at step 164.

Platform

  • Interpreter version: Python 3.12
  • Library version: fms-hf-tuning commit f1fd130

Sample Code

accelerate launch \
  --num_processes=8 \
  --dynamo_backend="no" \
  --fsdp_auto_wrap_policy="TRANSFORMER_BASED_WRAP" \
  --fsdp_cpu_ram_efficient_loading="true" \
  --fsdp_forward_prefetch="false" \
  --fsdp_offload_params="false" \
  --fsdp_sharding_strategy="HYBRID_SHARD" \
  --fsdp_state_dict_type="FULL_STATE_DICT" \
  --fsdp_sync_module_states="true" \
  --machine_rank="${RANK}" \
  --main_process_ip="${MASTER_ADDR}" \
  --main_process_port="${MASTER_PORT}" \
  --mixed_precision="no" \
  --num_machines="${WORLD_SIZE}" \
  --rdzv_backend="static" \
  --same_network \
  --use_fsdp \
  -m tuning.sft_trainer \
  --adam_beta1="0.9" \
  --adam_beta2="0.98" \
  --adam_epsilon="1e-10" \
  --aim_repo="${AIMSTACK_DB}" \
  --data_config="data_config.yaml" \
  --dataloader_drop_last="true" \
  --evaluation_strategy="no" \
  --experiment="lora-sft-vm-7b0512aa-c957-419d-9bfc-8f142f62ecfe" \
  --gradient_accumulation_steps="1" \
  --gradient_checkpointing="true" \
  --include_tokens_per_second="true" \
  --learning_rate="3e-05" \
  --logging_steps="10" \
  --logging_strategy="steps" \
  --lora_alpha="32" \
  --lora_dropout="0.1" \
  --lr_scheduler_type="cosine" \
  --max_seq_len="8192" \
  --max_steps="8000" \
  --model_name_or_path="ibm-granite/granite-3.1-8b-instruct" \
  --optim="adamw_torch" \
  --output_dir="/cos-mount/peft/lora/granite-3.1-8b-instruct/lora-9/" \
  --packing="False" \
  --peft_method="lora" \
  --per_device_train_batch_size="8" \
  -r="16" \
  --save_steps="1000" \
  --save_strategy="steps" \
  --split_batches="true" \
  --target_modules all-linear \
  --tracker="aim" \
  --use_flash_attn="true" \
  --use_reentrant="true" \
  --warmup_ratio="0.1" \
  --warmup_steps="2" \
  --weight_decay="0.1"

Dataprocessor

    dataprocessor:
        type: default
        sampling_stopping_strategy: all_exhausted
        seed: 66

Datahandlers

    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_field_name: input
            output_field_name: output

Expected behavior

Training should keep running when it encounters samples whose sequence length exceeds the specified max_seq_len (8192). Instead, the run failed after 164 steps, probably upon reaching such a sample.
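One way to check this hypothesis before training is to scan the tokenized sample lengths. The sketch below uses a hypothetical helper, `overlength_report`, which is not part of fms-hf-tuning. It assumes you already have per-sample token counts (prompt plus response). It batches samples in dataset order, ignoring shuffling. It also assumes that, without packing, each batch is padded to its longest sample.

```python
from typing import List, Sequence, Tuple


def overlength_report(
    lengths: Sequence[int], max_seq_len: int, batch_size: int
) -> Tuple[List[int], int]:
    """Hypothetical helper: return the indices of samples longer than max_seq_len
    and the largest padded batch in tokens (batch rows * longest row)."""
    over = [i for i, n in enumerate(lengths) if n > max_seq_len]
    worst = 0
    # Batch in dataset order; a real dataloader shuffles, so this is only an estimate.
    for start in range(0, len(lengths), batch_size):
        chunk = lengths[start:start + batch_size]
        # Assumes dynamic padding: every row is padded to the longest row in its batch.
        worst = max(worst, len(chunk) * max(chunk))
    return over, worst


if __name__ == "__main__":
    # Illustrative lengths only, not taken from the dataset in this issue.
    lengths = [512, 9000, 300, 8192, 20000, 1024]
    over, worst = overlength_report(lengths, max_seq_len=8192, batch_size=2)
    print(f"{len(over)} over-length samples at indices {over}; "
          f"largest padded batch = {worst} tokens")
```

A single over-length sample inflates the whole padded batch. With `--per_device_train_batch_size="8"`, one very long row can therefore multiply activation memory well beyond what max_seq_len would allow.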

Observed behavior

  2%|▏         | 164/8000 [08:59<6:06:33,  2.81s/it]
ERROR:sft_trainer.py:Traceback (most recent call last):
  File "/home/tuning/.local/lib/python3.12/site-packages/tuning/sft_trainer.py", line 676, in main
    trainer, additional_train_info = train(
                                     ^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/tuning/sft_trainer.py", line 420, in train
    trainer.train(resume_from_checkpoint)
  File "/home/tuning/.local/lib/python3.12/site-packages/trl/trainer/sft_trainer.py", line 434, in train
    output = super().train(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/transformers/trainer.py", line 2052, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/transformers/trainer.py", line 2388, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/transformers/trainer.py", line 3485, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/transformers/trainer.py", line 3532, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/peft/peft_model.py", line 1644, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/transformers/models/granite/modeling_granite.py", line 1087, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/transformers/models/granite/modeling_granite.py", line 880, in forward
    layer_outputs = self._gradient_checkpointing_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 255, in forward
    outputs = run_function(*args)
              ^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/transformers/models/granite/modeling_granite.py", line 623, in forward
    hidden_states = self.mlp(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/transformers/models/granite/modeling_granite.py", line 239, in forward
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
                                                           ^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.12/site-packages/peft/tuners/lora/layer.py", line 584, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.02 GiB. GPU 7 has a total capacity of 79.14 GiB of which 6.06 GiB is free. Process 1069103 has 73.07 GiB memory in use. Of the allocated memory 61.20 GiB is allocated by PyTorch, and 9.26 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
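The size of the failed allocation can be sanity-checked against max_seq_len. This is a rough sketch built on several assumptions that the issue does not state:

- The failing tensor is the LoRA branch output of `up_proj`, with shape [batch, seq_len, intermediate_size].
- `intermediate_size` is 12800 for granite-3.1-8b-instruct.
- The per-device batch really is 8. This also assumes `--split_batches` does not change it.

The element size is unknown, so the sketch tries both bf16 and fp32.

```python
GIB = 2**30


def activation_gib(batch: int, seq_len: int, out_features: int, bytes_per_elem: int) -> float:
    """GiB occupied by a dense [batch, seq_len, out_features] activation tensor."""
    return batch * seq_len * out_features * bytes_per_elem / GIB


def implied_tokens(alloc_gib: float, out_features: int, bytes_per_elem: int) -> float:
    """Inverse of activation_gib: total batch * seq_len tokens implied by an allocation."""
    return alloc_gib * GIB / (out_features * bytes_per_elem)


# Assumption: up_proj output width (intermediate_size) of granite-3.1-8b-instruct.
INTERMEDIATE_SIZE = 12800
BATCH = 8                # --per_device_train_batch_size (assumed to be the per-device batch)
MAX_SEQ_LEN = 8192       # --max_seq_len
FAILED_ALLOC_GIB = 8.02  # "Tried to allocate 8.02 GiB" from the error message

if __name__ == "__main__":
    # The element size of the LoRA branch output is not stated in the issue, so try both.
    for dtype, nbytes in (("bf16", 2), ("fp32", 4)):
        capped = activation_gib(BATCH, MAX_SEQ_LEN, INTERMEDIATE_SIZE, nbytes)
        per_sample = implied_tokens(FAILED_ALLOC_GIB, INTERMEDIATE_SIZE, nbytes) / BATCH
        print(f"{dtype}: a batch capped at max_seq_len needs {capped:.2f} GiB; "
              f"{FAILED_ALLOC_GIB} GiB implies ~{per_sample:,.0f} tokens per sample")
```

Under either element size, the 8.02 GiB allocation corresponds to roughly 21k to 42k tokens per sample. That is far above 8192. This is consistent with over-length samples reaching the model untruncated, though it does not prove it.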

Additional context

This may be happening because the data handler does not truncate sequences that exceed max_seq_len. Over-length samples would then reach the model and trigger the OOM error.
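For illustration, truncation at the masking step could look like the sketch below. It is hypothetical and dependency-free, and it operates on already-tokenized ids. It is not the actual implementation of `tokenize_and_apply_input_masking` in fms-hf-tuning, and the function name `mask_and_truncate` is made up.

```python
from typing import Dict, List

# Label value ignored by torch.nn.CrossEntropyLoss (its default ignore_index).
IGNORE_INDEX = -100


def mask_and_truncate(
    input_ids: List[int], output_ids: List[int], max_seq_len: int
) -> Dict[str, List[int]]:
    """Hypothetical sketch: concatenate prompt and response ids, mask the prompt
    out of the labels, and truncate every field to at most max_seq_len tokens."""
    ids = (input_ids + output_ids)[:max_seq_len]
    labels = ([IGNORE_INDEX] * len(input_ids) + output_ids)[:max_seq_len]
    return {"input_ids": ids, "labels": labels, "attention_mask": [1] * len(ids)}


if __name__ == "__main__":
    sample = mask_and_truncate([11, 12, 13], [21, 22, 23], max_seq_len=4)
    print(sample)
```

Truncating from the right keeps the prompt and drops the tail of the response. If the prompt alone is longer than max_seq_len, every label becomes IGNORE_INDEX and the sample provides no training signal. Filtering such samples out is an alternative to truncating them.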
