support prompt_logp_compute_kv_cache in no vllm trainer #82

Open · wants to merge 8 commits into base: main

Conversation


@Yukino256 commented Feb 13, 2025

Solves this issue: #71. The code is mainly copied and modified from andyl98:grpo-vram-optimization.

In my test, GRPO runs at least 3x faster, without OOM, on the Qwen2VL-7B model:
[screenshot]

Since I have never successfully run the vLLM version, I can't modify the vllm_trainer code.
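For context, the core idea (judging from the traceback discussed later in this thread, and from andyl98:grpo-vram-optimization) is a _get_per_token_logps that takes num_logits_to_keep and mini_batch_size, so only the completion logits are materialized and only a few sequences are scored per forward pass. A minimal sketch, assuming one image per sequence and a transformers version whose Qwen2-VL forward accepts num_logits_to_keep; this is not the PR's exact code:

import torch
import torch.nn.functional as F

def _get_per_token_logps(model, input_ids, attention_mask, pixel_values,
                         image_grid_thw, num_logits_to_keep, mini_batch_size):
    # Qwen2-VL packs all image patches into one tensor; grid_thw gives the
    # patch count per image (t * h * w). Assumes one image per sequence.
    patches_per_image = image_grid_thw.prod(dim=-1)
    patch_offsets = F.pad(patches_per_image.cumsum(0), (1, 0))  # leading 0

    all_logps = []
    for i in range(0, input_ids.size(0), mini_batch_size):
        j = min(i + mini_batch_size, input_ids.size(0))
        logits = model(
            input_ids=input_ids[i:j],
            attention_mask=attention_mask[i:j],
            pixel_values=pixel_values[patch_offsets[i]:patch_offsets[j]],
            image_grid_thw=image_grid_thw[i:j],
            # +1 so we keep the position that predicts the first scored token
            num_logits_to_keep=num_logits_to_keep + 1,
        ).logits[:, :-1, :]  # drop the last position; it predicts beyond the sequence
        targets = input_ids[i:j, -num_logits_to_keep:]  # completion tokens
        logps = logits.log_softmax(dim=-1)
        all_logps.append(torch.gather(logps, 2, targets.unsqueeze(-1)).squeeze(-1))
    return torch.cat(all_logps, dim=0)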

My test script is (the torchrun launcher prefix is omitted here; see the full script later in this thread):

src/open_r1/grpo.py \
--deepspeed local_scripts/zero3.json \
--output_dir="${OUTPUT_DIR}" \
--model_name_or_path="${MODEL_PATH}" \
--dataset_name="${DATA_PATH}" \
--max_prompt_length 8192 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--logging_steps 1 \
--bf16 \
--report_to wandb \
--gradient_checkpointing false \
--attn_implementation flash_attention_2 \
--max_pixels 2359296 \
--save_total_limit 8 \
--num_train_epochs 2 \
--run_name Qwen2-VL-2B-8k \
--save_steps 100 \
--save_only_model true

You can also add:
--logit_computation_mini_batch_size X
if your trl package is up to date.
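Smaller values lower the peak memory of the per-token log-prob computation at some cost in throughput. For example, appended to the command above (the value 2 is purely illustrative):

--logit_computation_mini_batch_size 2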

@chenllliang (Member)

Hi, thanks for your contribution. Can you provide a more detailed running-time comparison, and a performance comparison on GEOQA and CLEVR?

@Yukino256 (Author)

> geoqa

Hi, thank you. I will try it as soon as possible 🤗

@Yukino256 (Author) commented Feb 14, 2025

> Hi, thanks for your contribution. Can you provide a more detailed running-time comparison, and a performance comparison on GEOQA and CLEVR?

@chenllliang
Hi, I'm sorry, I made a mistake: I used the old code in which the batch-decoding error was not yet fixed, so the speed-up does not in fact exist.

Using Qwen2-VL-7B-Instruct for the test:

On the GEOQA dataset, the raw code runs at 105 s/it and my changed code at 114 s/it on 8x A800 80G.
However, the raw code hits an OOM error in the second iteration:
[screenshot]
while my changed code indeed runs well without OOM:
[screenshot]

My test script is:

export DEBUG_MODE="true"
export LOG_PATH="./debug_log_GEOQA.txt"

OUTPUT_DIR=/grpo-result-7b-RAW-GEOQA
MODEL_PATH=/Qwen2-VL-7B-Instruct
DATA_PATH=/PKUGEOQA_R1V_Train_8K


set -x
set -e
set -u

export LANG=en_US.UTF-8
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_GID_INDEX=3
export NCCL_IB_DISABLE=0
export NCCL_IB_RETRY_CNT=7
export CUDA_LAUNCH_BLOCKING=0
export NCCL_DEBUG=info


export WANDB_BASE_URL=https://api.wandb.ai
export WANDB_PROJECT=r1-test
export WANDB_API_KEY="xxxxxx"
WANDB_RUN_NAME=GEOQA-RAW
wandb login $WANDB_API_KEY

NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l)

torchrun --nnodes=$WORLD_SIZE --nproc_per_node=$NUM_GPUS_PER_NODE --node_rank=$RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
src/open_r1/grpo.py \
--deepspeed local_scripts/zero3.json \
--output_dir="${OUTPUT_DIR}" \
--model_name_or_path="${MODEL_PATH}" \
--dataset_name="${DATA_PATH}" \
--max_prompt_length 8192 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--num_generations 8 \
--logging_steps 1 \
--bf16 \
--report_to wandb \
--gradient_checkpointing false \
--attn_implementation flash_attention_2 \
--max_pixels 2359296 \
--save_total_limit 8 \
--num_train_epochs 10 \
--run_name $WANDB_RUN_NAME \
--save_steps 100 \
--save_only_model true

@ZCMax commented Feb 14, 2025

[rank0]:   File "/mnt/petrelfs/zhuchenming/R1-V/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py", line 470, in compute_loss
[rank0]:     per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
[rank0]: TypeError: Qwen2VLGRPOTrainer._get_per_token_logps() missing 2 required positional arguments: 'num_logits_to_keep' and 'mini_batch_size'

After updating to your commit, an error seems to occur.

@Yukino256 (Author) commented Feb 14, 2025

> [rank0]:   File "/mnt/petrelfs/zhuchenming/R1-V/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py", line 470, in compute_loss
> [rank0]:     per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
> [rank0]: TypeError: Qwen2VLGRPOTrainer._get_per_token_logps() missing 2 required positional arguments: 'num_logits_to_keep' and 'mini_batch_size'
>
> After updating to your commit, an error seems to occur.

Hello! It seems the code was not updated on your side? I changed that line into multiple lines; maybe git didn't sync properly?
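For reference, the TypeError just means the call site in compute_loss still uses the old five-argument form; with the updated code it should pass the two new arguments, roughly like this (a sketch; the argument names come from the traceback above, and the comments are assumptions):

per_token_logps = self._get_per_token_logps(
    model,
    prompt_completion_ids,
    attention_mask,
    pixel_values,
    image_grid_thw,
    num_logits_to_keep,   # e.g. the number of completion tokens to score
    mini_batch_size,      # e.g. from --logit_computation_mini_batch_size
)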

@ZCMax commented Feb 14, 2025

Sorry, but I think you didn't update the code in the right place ~

@ZCMax commented Feb 14, 2025

I also have another question: I set mini_batch_size to 1, and my prompts are around 2100 tokens, but it still OOMs on 8x A100 80G.

@Yukino256 (Author)

> I also have another question: I set mini_batch_size to 1, and my prompts are around 2100 tokens, but it still OOMs on 8x A100 80G.

Hello, I'm looking into the code error. As for the OOM error, is --deepspeed local_scripts/zero3.json added? 😭😭
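(For anyone who doesn't have that file: a minimal sketch of what a local_scripts/zero3.json for this setup typically looks like under the HF Trainer + DeepSpeed integration; the repo's actual config may differ.)

{
    "bf16": { "enabled": "auto" },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}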

@Yukino256 Yukino256 closed this Feb 14, 2025
@Yukino256 Yukino256 reopened this Feb 14, 2025
@Yukino256 (Author)

@ZCMax Hello! The bugs should be fixed now! I think running 7B with zero3 is OK! 🥲🥲
[screenshot]

@CAOANJIA

> @ZCMax Hello! The bugs should be fixed now! I think running 7B with zero3 is OK! 🥲🥲 [screenshot]

A question: once zero3 is added, multi-node training seems to hang?

@Yukino256 (Author) commented Feb 19, 2025

> @ZCMax Hello! The bugs should be fixed now! I think running 7B with zero3 is OK! 🥲🥲 [screenshot]
>
> A question: once zero3 is added, multi-node training seems to hang?

Hello! As far as I can tell, the original code currently can't run multi-node at all; I've only ever run it on a single node with 8 GPUs. It seems they haven't implemented multi-node yet?
See #57

@rrustlee

@Yukino256 Have you run into a similar error? It comes from inside utils.py and should be caused by o3; it doesn't seem to affect the results, but I'm still a bit puzzled.
[screenshot]

@Yukino256 (Author)

> @Yukino256 Have you run into a similar error? It comes from inside utils.py and should be caused by o3; it doesn't seem to affect the results, but I'm still a bit puzzled. [screenshot]

I get this with the original source code as well, so it shouldn't be something introduced by my code.
