Flash attention error when training in latest environment #120

Open
daydayup2100 opened this issue Feb 19, 2025 · 2 comments

daydayup2100 commented Feb 19, 2025

The error shown below occurred when I ran the training script. Something seems to be wrong in the conda environment set up according to the latest "setup.sh", even though setup itself completed without errors. I have only encountered this problem in environments configured from the current code release, not in environments configured from the previous one.

Traceback (most recent call last):
  File "/home/user/miniconda3/envs/r1v/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1863, in _get_module
    return importlib.import_module("." + module_name, self.__name__)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/r1v/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/user/miniconda3/envs/r1v/lib/python3.11/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 38, in <module>
    from ...modeling_utils import PreTrainedModel
  File "/home/user/miniconda3/envs/r1v/lib/python3.11/site-packages/transformers/modeling_utils.py", line 50, in <module>
    from .integrations.flash_attention import flash_attention_forward
  File "/home/user/miniconda3/envs/r1v/lib/python3.11/site-packages/transformers/integrations/flash_attention.py", line 5, in <module>
    from ..modeling_flash_attention_utils import _flash_attention_forward
  File "/home/user/miniconda3/envs/r1v/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 30, in <module>
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/r1v/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/home/user/miniconda3/envs/r1v/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 15, in <module>
    import flash_attn_2_cuda as flash_attn_gpu
ImportError: /home/user/miniconda3/envs/r1v/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1011StorageImpl27throw_data_ptr_access_errorEv

environment:
8×H200, CUDA 12.4, torch 2.6.0, flash_attn 2.7.4.post1

training script:

cd src/r1-v

export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
export LOG_PATH="./debug_log_2b.txt"

torchrun --nproc_per_node="8" \
    --nnodes="1" \
    --node_rank="0" \
    --master_addr="127.0.0.1" \
    --master_port="12345" \
    src/open_r1/grpo.py \
    --output_dir <OUTPUT_DIR> \
    --model_name_or_path <PATH-TO-Qwen2-VL-2B-Instruct> \
    --dataset_name leonardPKU/clevr_cogen_a_train \
    --deepspeed local_scripts/zero3.json \
    --max_prompt_length 512 \
    --max_completion_length 512 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --logging_steps 1 \
    --bf16 \
    --report_to wandb \
    --gradient_checkpointing false \
    --attn_implementation flash_attention_2 \
    --max_pixels 401408 \
    --num_train_epochs 2 \
    --run_name Qwen2-VL-2B-GRPO-CLEVR-70k \
    --save_steps 100 \
    --save_only_model true \
    --num_generations 8   # number of outputs G in grpo, reduce it would lead to faster training and smaller memory cost but higher variance  

I have considered the possibility of a conflict between the CUDA version and flash attention, and tried changing the CUDA version and reconfiguring the conda environment, but I still hit the same error. Has anyone else run into a similar situation?
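
For reference, a quick way to narrow this down (a sketch only; the env name r1v is taken from the paths in the traceback) is to compare the installed torch with the installed flash-attn and to reproduce the failing import in isolation:

conda activate r1v
# report the installed versions that need to match
pip show torch flash-attn | grep -E "^(Name|Version)"
python -c "import torch; print(torch.__version__, torch.version.cuda)"
# reproduce the failing import on its own; an undefined-symbol error here
# usually means the flash-attn wheel was compiled against a different torch
python -c "import flash_attn_2_cuda"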

LiuRicky (Contributor) commented

The "# vLLM support" step pip install vllm==0.7.2 in setup.sh will downgrade torch to 2.5.0; please check whether your torch has been downgraded. You can comment out pip install vllm==0.7.2 if you do not use vLLM, then reinstall from a new conda env.
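
A rough sketch of that suggestion (the version pins are taken from this thread; rebuilding flash-attn with --no-build-isolation is a common workaround and an assumption here, not something setup.sh prescribes):

# 1) check whether installing vllm replaced the torch that flash-attn was built against
python -c "import torch; print(torch.__version__)"

# 2) if you do not need vLLM, comment out this line in setup.sh and recreate the conda env:
#    pip install vllm==0.7.2

# 3) alternatively, rebuild flash-attn against the torch that is actually installed
pip uninstall -y flash-attn
pip install flash-attn==2.7.4.post1 --no-build-isolation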

chenyangzhu1 commented

> The "# vLLM support" step pip install vllm==0.7.2 in setup.sh will downgrade torch to 2.5.0; please check whether your torch has been downgraded. You can comment out pip install vllm==0.7.2 if you do not use vLLM, then reinstall from a new conda env.

After installing the conda environment according to setup.sh, I encounter this error regardless of whether I use vllm or not.
