
[BUG] Error when training VLM with GKD when per_device_train_batch_size = 1 #7034

@uyzhang

Describe the bug
What the bug is and how to reproduce it, preferably with screenshots.

With per_device_train_batch_size=1, training fails with a CUDA device-side assert (see the logs below). With any value greater than 1, the same command runs without error.

PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
MASTER_PORT=29501 \
NPROC_PER_NODE=4 \
swift rlhf \
    --rlhf_type gkd \
    --model /root/cache/hub/models--Qwen--Qwen2.5-VL-3B-Instruct/snapshots/66285546d2b821cf421d4f5eb2576359d3770cd3 \
    --teacher_model /root/cache/hub/models--Qwen--Qwen2.5-VL-3B-Instruct/snapshots/66285546d2b821cf421d4f5eb2576359d3770cd3 \
    --dataset 'modelscope/coco_2014_caption:validation#2000' \
    --load_from_cache_file true \
    --split_dataset_ratio 0.01 \
    --train_type full \
    --seq_kd true \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 4 \
    --learning_rate 1e-5 \
    --freeze_vit true \
    --gradient_accumulation_steps 1 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 2 \
    --deepspeed zero2 \
    --attn_impl flash_attn \
    --logging_steps 5 \
    --max_length 4096 \
    --max_completion_length 512 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --save_only_model true

Your hardware and system info
Write your system info here, such as CUDA version, operating system, GPU model, and torch version.

CUDA: 12.9

GPU: H20

Package                       Version      Editable project location
----------------------------- ------------ ------------------------------------------------------------------
abnf                          2.2.0
absl-py                       2.3.1
accelerate                    1.12.0
addict                        2.4.0
aiofiles                      24.1.0
aiohappyeyeballs              2.6.1
aiohttp                       3.13.2
aiosignal                     1.4.0
aliyun-python-sdk-core        2.16.0
aliyun-python-sdk-kms         2.16.5
annotated-doc                 0.0.4
annotated-types               0.7.0
antlr4-python3-runtime        4.9.3
anyio                         4.11.0
attrdict                      2.0.1
attrs                         25.4.0
av                            16.0.1
backoff                       2.2.1
binpacking                    1.5.2
bitsandbytes                  0.48.2
brotli                        1.2.0
causal_conv1d                 1.5.4
certifi                       2025.11.12
cffi                          2.0.0
chardet                       5.2.0
charset-normalizer            3.4.4
cint                          1.0.0
click                         8.3.1
contourpy                     1.3.3
cpm-kernels                   1.0.11
crcmod                        1.7
cryptography                  46.0.3
cut-cross-entropy             25.1.1
cycler                        0.12.1
dacite                        1.9.2
datasets                      3.6.0
deepspeed                     0.18.2
diffusers                     0.35.2
dill                          0.3.8
diskcache                     5.6.3
distro                        1.9.0
docstring_parser              0.17.0
einops                        0.8.1
fastapi                       0.122.0
ffmpy                         1.0.0
fickling                      0.1.5
filelock                      3.20.0
fla-core                      0.4.0
flash_attn                    2.8.0.post2
flash-linear-attention        0.4.0
fonttools                     4.60.1
frozenlist                    1.8.0
fsspec                        2025.3.0
future                        1.0.0
gitdb                         4.0.12
GitPython                     3.1.45
gql                           4.0.0
gradio                        6.0.1
gradio_client                 2.0.0
graphql-core                  3.2.7
graphviz                      0.21
groovy                        0.1.2
grpcio                        1.76.0
h11                           0.16.0
hf_transfer                   0.1.9
hf-xet                        1.2.0
hjson                         3.1.0
httpcore                      1.0.9
httpx                         0.28.1
huggingface-hub               0.36.0
idna                          3.11
importlib_metadata            8.7.0
intervaltree                  3.1.0
jieba                         0.42.1
Jinja2                        3.1.6
jiter                         0.12.0
jmespath                      0.10.0
joblib                        1.5.2
json_repair                   0.54.2
jsonschema                    4.25.1
jsonschema-specifications     2025.9.1
kaitaistruct                  0.11
kiwisolver                    1.4.9
liger_kernel                  0.6.4
Markdown                      3.10
markdown-it-py                4.0.0
MarkupSafe                    3.0.3
matplotlib                    3.10.7
mdurl                         0.1.2
modelscope                    1.32.0
mpi4py                        4.1.1
mpmath                        1.3.0
ms_swift                      3.11.0.dev0
msgpack                       1.1.2
msgspec                       0.20.0
multidict                     6.7.0
multiprocess                  0.70.16
networkx                      3.6
ninja                         1.13.0
nltk                          3.9.2
numpy                         2.3.5
nvidia-cublas-cu12            12.8.4.1
nvidia-cuda-cupti-cu12        12.8.90
nvidia-cuda-nvrtc-cu12        12.8.93
nvidia-cuda-runtime-cu12      12.8.90
nvidia-cudnn-cu12             9.10.2.21
nvidia-cufft-cu12             11.3.3.83
nvidia-cufile-cu12            1.13.1.3
nvidia-curand-cu12            10.3.9.90
nvidia-cusolver-cu12          11.7.3.90
nvidia-cusparse-cu12          12.5.8.93
nvidia-cusparselt-cu12        0.7.1
nvidia-ml-py                  13.580.82
nvidia-nccl-cu12              2.27.3
nvidia-nvjitlink-cu12         12.8.93
nvidia-nvtx-cu12              12.8.90
omegaconf                     2.3.0
openai                        2.8.1
orjson                        3.11.4
oss2                          2.19.1
packaging                     25.0
pandas                        2.3.3
pdfminer.six                  20250506
peft                          0.18.0
pillow                        12.0.0
pip                           25.3
platformdirs                  4.5.0
polyfile-weave                0.5.7
propcache                     0.4.1
protobuf                      6.33.1
psutil                        7.1.3
py-cpuinfo                    9.0.0
pyarrow                       22.0.0
pycparser                     2.23
pycryptodome                  3.23.0
pydantic                      2.12.4
pydantic_core                 2.41.5
pydub                         0.25.1
Pygments                      2.19.2
pyparsing                     3.2.5
python-dateutil               2.9.0.post0
python-multipart              0.0.20
pytz                          2025.2
PyYAML                        6.0.3
qwen-vl-utils                 0.0.14
referencing                   0.37.0
regex                         2025.11.3
requests                      2.32.5
rich                          14.2.0
rouge                         1.0.1
rpds-py                       0.29.0
safehttpx                     0.1.7
safetensors                   0.7.0
scipy                         1.16.3
semantic-version              2.10.0
sentencepiece                 0.2.1
sentry-sdk                    2.46.0
setuptools                    80.9.0
shellingham                   1.5.4
shtab                         1.8.0
simplejson                    3.20.2
six                           1.17.0
smmap                         5.0.2
sniffio                       1.3.1
sortedcontainers              2.4.0
starlette                     0.50.0
stdlib-list                   0.11.1
sympy                         1.14.0
tenacity                      9.1.2
tensorboard                   2.20.0
tensorboard-data-server       0.7.2
tiktoken                      0.12.0
tokenizers                    0.22.1
tomlkit                       0.13.3
torch                         2.8.0
torchao                       0.13.0
torchaudio                    2.8.0+cu128
torchvision                   0.23.0+cu128
tqdm                          4.67.1
transformers                  4.57.1
transformers-stream-generator 0.0.5
triton                        3.4.0
trl                           0.24.0
typeguard                     4.4.4
typer                         0.20.0
typing_extensions             4.15.0
typing-inspection             0.4.2
tyro                          0.9.35
tzdata                        2025.2
unsloth                       2025.11.4
unsloth_zoo                   2025.11.5
urllib3                       2.5.0
uvicorn                       0.38.0
wandb                         0.23.0
weave                         0.52.20
Werkzeug                      3.1.3
wheel                         0.45.1
xformers                      0.0.32.post2
xxhash                        3.6.0
yarl                          1.22.0
zipp                          3.23.0
zstandard                     0.25.0

Additional context
Add any other context about the problem here.

Detailed logs:

[rank3]: Traceback (most recent call last):
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/cli/rlhf.py", line 7, in <module>
[rank3]:     rlhf_main()
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/llm/train/rlhf.py", line 233, in rlhf_main
[rank3]:     return SwiftRLHF(args).main()
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/llm/base.py", line 49, in main
[rank3]:     result = self.run()
[rank3]:              ^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/ray/base.py", line 170, in wrapper
[rank3]:     return func(self, *args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/llm/train/sft.py", line 209, in run
[rank3]:     return self.train(trainer)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/llm/train/sft.py", line 257, in train
[rank3]:     trainer.train(trainer.args.resume_from_checkpoint)
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/trainers/mixin.py", line 821, in train
[rank3]:     res = super().train(*args, **kwargs)
[rank3]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/trainer.py", line 2325, in train
[rank3]:     return inner_training_loop(
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/trainer.py", line 2674, in _inner_training_loop
[rank3]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/trainers/rlhf_trainer/utils.py", line 428, in wrapper
[rank3]:     return func(self, *args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/trainers/rlhf_trainer/gkd_trainer.py", line 334, in training_step
[rank3]:     new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
[rank3]:                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/trainers/rlhf_trainer/gkd_trainer.py", line 123, in generate_on_policy_outputs
[rank3]:     generated_outputs = model.generate(
[rank3]:                         ^^^^^^^^^^^^^^^
[rank3]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank3]:     return func(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/generation/utils.py", line 2564, in generate
[rank3]:     result = decoding_method(
[rank3]:              ^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/generation/utils.py", line 2787, in _sample
[rank3]:     outputs = model_forward(**model_inputs, return_dict=True)
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank3]:     return inner()
[rank3]:            ^^^^^^^
[rank3]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/utils/generic.py", line 918, in wrapper
[rank3]:     output = func(self, *args, **kwargs)
[rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1476, in forward
[rank3]:     outputs = self.model(
[rank3]:               ^^^^^^^^^^^
[rank3]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1254, in forward
[rank3]:     inputs_embeds = self.get_input_embeddings()(input_ids)
[rank3]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/nn/modules/sparse.py", line 192, in forward
[rank3]:     return F.embedding(
[rank3]:            ^^^^^^^^^^^^
[rank3]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/nn/functional.py", line 2546, in embedding
[rank3]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: torch.AcceleratorError: CUDA error: device-side assert triggered
[rank3]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
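
The rank3 traceback above ends in F.embedding. One common cause of a device-side assert at that call is a token id outside the embedding table's range, although this report does not confirm that this is what happens here. A minimal sketch of such a check, assuming the ids have been copied to nested Python lists (for example with `.tolist()`) and that `vocab_size` is the number of rows in the embedding table; the helper name `find_out_of_range_ids` and the sample values are hypothetical:

```python
from typing import List, Tuple


def find_out_of_range_ids(input_ids: List[List[int]], vocab_size: int) -> List[Tuple[int, int, int]]:
    """Return (row, column, token_id) for every id outside [0, vocab_size)."""
    bad = []
    for row, sequence in enumerate(input_ids):
        for col, token_id in enumerate(sequence):
            if token_id < 0 or token_id >= vocab_size:
                bad.append((row, col, token_id))
    return bad


# Hypothetical values for illustration only, not taken from the failing run.
example_ids = [[5, 17, 42], [3, -100, 99]]
print(find_out_of_range_ids(example_ids, vocab_size=50))
# prints [(1, 1, -100), (1, 2, 99)]
```

An empty result would rule out this particular cause for the batch that was checked.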

[rank0]: Traceback (most recent call last):
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/cli/rlhf.py", line 7, in <module>
[rank0]:     rlhf_main()
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/llm/train/rlhf.py", line 233, in rlhf_main
[rank0]:     return SwiftRLHF(args).main()
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/llm/base.py", line 49, in main
[rank0]:     result = self.run()
[rank0]:              ^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/ray/base.py", line 170, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/llm/train/sft.py", line 209, in run
[rank0]:     return self.train(trainer)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/llm/train/sft.py", line 257, in train
[rank0]:     trainer.train(trainer.args.resume_from_checkpoint)
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/trainers/mixin.py", line 821, in train
[rank0]:     res = super().train(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/trainer.py", line 2325, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/trainer.py", line 2674, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/trainers/rlhf_trainer/utils.py", line 428, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/trainers/rlhf_trainer/gkd_trainer.py", line 334, in training_step
[rank0]:     new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
[rank0]:                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/ms-swift/swift/trainers/rlhf_trainer/gkd_trainer.py", line 123, in generate_on_policy_outputs
[rank0]:     generated_outputs = model.generate(
[rank0]:                         ^^^^^^^^^^^^^^^
[rank0]:   File "/jizhicfs/leoyizhang/anaconda3/envs/beelinear/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/generation/utils.py", line 2564, in generate
[rank0]:     result = decoding_method(
[rank0]:              ^^^^^^^^^^^^^^^^
[rank0]:   File "/apdcephfs_private/qy/projects/zy/BeeLinear/transformers/src/transformers/generation/utils.py", line 2829, in _sample
[rank0]:     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: torch.AcceleratorError: CUDA error: device-side assert triggered
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
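
The rank0 traceback above ends in torch.multinomial, which requires each row of probabilities to be non-negative and finite with a non-zero sum. Because CUDA errors are reported asynchronously, the assert may also originate from an earlier kernel, so this is only one possibility. A minimal sketch of the validity check, assuming one row of probabilities has been copied to a plain Python list; the helper names `invalid_prob_indices` and `is_samplable` and the sample values are hypothetical:

```python
import math
from typing import List


def invalid_prob_indices(probs: List[float]) -> List[int]:
    """Return the positions whose value is NaN, infinite, or negative."""
    return [i for i, p in enumerate(probs) if math.isnan(p) or math.isinf(p) or p < 0]


def is_samplable(probs: List[float]) -> bool:
    """True if every entry is finite and non-negative and the row sum is non-zero."""
    return not invalid_prob_indices(probs) and sum(probs) > 0


# Hypothetical values for illustration only, not taken from the failing run.
row = [0.2, float("nan"), 0.8, -0.1]
print(invalid_prob_indices(row))
# prints [1, 3]
print(is_samplable(row))
# prints False
```

If a row fails this check, the logits that produced it (before softmax) are the next place to look.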

log.txt
