I hit the error below when running full-parameter DPO on Mistral-Nemo with FlashAttention-2 and DeepSpeed ZeRO-3:
```
3%|▎ | 20/700 [16:42<9:24:21, 49.80s/it]
3%|▎ | 21/700 [17:32<9:22:23, 49.70s/it]
3%|▎ | 22/700 [18:21<9:20:25, 49.60s/it]
3%|▎ | 23/700 [19:11<9:19:28, 49.58s/it]
[rank3]: Traceback (most recent call last):
[rank3]:   File "/scratch/azureml/cr/j/4bcb21f80f454bf083c17d28836f301c/exe/wd/./src/train.py", line 28, in <module>
[rank3]:     main()
[rank3]:   File "/scratch/azureml/cr/j/4bcb21f80f454bf083c17d28836f301c/exe/wd/./src/train.py", line 19, in main
[rank3]:     run_exp()
[rank3]:   File "/scratch/azureml/cr/j/4bcb21f80f454bf083c17d28836f301c/exe/wd/src/llamafactory/train/tuner.py", line 56, in run_exp
[rank3]:     run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
[rank3]:   File "/scratch/azureml/cr/j/4bcb21f80f454bf083c17d28836f301c/exe/wd/src/llamafactory/train/dpo/workflow.py", line 89, in run_dpo
[rank3]:     train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/trainer.py", line 2171, in train
[rank3]:     return inner_training_loop(
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/trainer.py", line 2531, in _inner_training_loop
[rank3]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank3]:   File "/scratch/azureml/cr/j/4bcb21f80f454bf083c17d28836f301c/exe/wd/src/llamafactory/train/dpo/trainer.py", line 356, in training_step
[rank3]:     return super().training_step(model, inputs, *args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/trainer.py", line 3675, in training_step
[rank3]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank3]:   File "/scratch/azureml/cr/j/4bcb21f80f454bf083c17d28836f301c/exe/wd/src/llamafactory/train/dpo/trainer.py", line 311, in compute_loss
[rank3]:     loss = super().compute_loss(model, inputs, return_outputs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1408, in compute_loss
[rank3]:     loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
[rank3]:   File "/scratch/azureml/cr/j/4bcb21f80f454bf083c17d28836f301c/exe/wd/src/llamafactory/train/dpo/trainer.py", line 264, in get_batch_loss_metrics
[rank3]:     ) = self.concatenated_forward(model, batch)
[rank3]:   File "/scratch/azureml/cr/j/4bcb21f80f454bf083c17d28836f301c/exe/wd/src/llamafactory/train/dpo/trainer.py", line 214, in concatenated_forward
[rank3]:     all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank3]:     ret_val = func(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1914, in forward
[rank3]:     loss = self.module(*inputs, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 832, in forward
[rank3]:     outputs = self.model(
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 549, in forward
[rank3]:     layer_outputs = self._gradient_checkpointing_func(
[rank3]:   File "/scratch/azureml/cr/j/4bcb21f80f454bf083c17d28836f301c/exe/wd/src/llamafactory/model/model_utils/checkpointing.py", line 93, in custom_gradient_checkpointing_func
[rank3]:     return gradient_checkpointing_func(func, *args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
[rank3]:     return disable_fn(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
[rank3]:     return CheckpointFunction.apply(function, preserve, *args)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 255, in forward
[rank3]:     outputs = run_function(*args)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 246, in forward
[rank3]:     hidden_states, self_attn_weights = self.self_attn(
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 183, in forward
[rank3]:     attn_output, attn_weights = attention_interface(
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/integrations/flash_attention.py", line 50, in flash_attention_forward
[rank3]:     attn_output = _flash_attention_forward(
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 311, in _flash_attention_forward
[rank3]:     attn_output_unpad = flash_attn_varlen_func(
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 1448, in flash_attn_varlen_func
[rank3]:     return FlashAttnVarlenFunc.apply(
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 930, in forward
[rank3]:     out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_ops.py", line 1061, in __call__
[rank3]:     return self_._op(*args, **(kwargs or {}))
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_library/autograd.py", line 98, in autograd_impl
[rank3]:     result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_library/autograd.py", line 40, in forward
[rank3]:     result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_ops.py", line 672, in redispatch
[rank3]:     return self_._handle.redispatch_boxed(keyset, *args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 236, in backend_impl
[rank3]:     result = self._backend_fns[device_type](*args, **kwargs)
[rank3]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 170, in _flash_attn_varlen_forward
[rank3]:     out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
[rank3]: RuntimeError: shape '[2, 8, 4, 128]' is invalid for input of size 4096
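```

For what it's worth, the target shape in the failing reshape needs 2 × 8 × 4 × 128 = 8192 elements, but the tensor that actually reaches `flash_attn_gpu.varlen_fwd` only has 4096, i.e. exactly a factor of 2 too few. A minimal standalone illustration of that mismatch (hypothetical tensor, not the library's internals):

```python
import torch

# Hypothetical illustration only: a tensor with 4096 elements cannot be viewed
# as [2, 8, 4, 128], which requires 2 * 8 * 4 * 128 = 8192 elements -- exactly
# twice what the kernel receives in the traceback above.
kv = torch.empty(4096)
try:
    kv.view(2, 8, 4, 128)
except RuntimeError as e:
    print(e)  # shape '[2, 8, 4, 128]' is invalid for input of size 4096
```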
My training command is:
```sh
--deepspeed ./config/ds_z3_bf16_dpo.json \
--stage dpo --do_train --do_eval \
--dataset DPO_v1.2 --val_size 500 \
--template xverse \
--finetuning_type full \
--overwrite_cache \
--lr_scheduler_type cosine \
--logging_steps 10 \
--learning_rate 5e-07 \
--plot_loss \
--bf16 \
--cutoff_len 20480 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--per_device_eval_batch_size 1 \
--max_grad_norm 0.3 \
--warmup_ratio 0.05 \
--max_steps 700 \
--evaluation_strategy steps \
--eval_steps 35 \
--save_steps 35 \
--save_total_limit 20 \
--date_time 03-10 \
--overwrite_output_dir \
--flash_attn fa2 \
--ddp_timeout 720000 \
--preprocessing_num_workers 32 \
--log_level info \
--log_on_each_node false \
--report_to wandb \
--metric_for_best_model rewards/accuracies \
--greater_is_better True \
--pref_beta 0.4 \
--pref_ftx 0.05 \
--sequence_parallel_size 4 \
--save_only_model True
```

Since the run combines `--sequence_parallel_size 4` with `--cutoff_len 20480`, each sequence-parallel rank should see 20480 / 4 = 5120 tokens when a packed sequence splits evenly. I have not confirmed this is the cause, but a hypothetical sanity check of that divisibility (illustrative names, not LLaMA-Factory's actual code) would look like:

```python
def tokens_per_rank(seq_len: int, sp_size: int) -> int:
    """Hypothetical check: a packed sequence must split evenly across
    sequence-parallel ranks, otherwise the per-rank shapes fed to the
    attention kernel will not match what it expects."""
    if seq_len % sp_size != 0:
        raise ValueError(
            f"sequence length {seq_len} is not divisible by "
            f"sequence_parallel_size {sp_size}; pad or truncate first"
        )
    return seq_len // sp_size

print(tokens_per_rank(20480, 4))  # 5120 tokens per rank for cutoff_len=20480
```