Skip to content

Commit a5c4872

Browse files
authored
[train] support Ovis2.5 padding_free (#5486)
1 parent 5ff8d5b commit a5c4872

File tree

6 files changed

+16
-26
lines changed

6 files changed

+16
-26
lines changed

examples/models/ovis2/train.sh

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
1-
# 28GiB
1+
# 17GiB
22

33
pip install "transformers==4.51.*"
44

55
CUDA_VISIBLE_DEVICES=0 \
66
swift sft \
7-
--model AIDC-AI/Ovis2-8B \
8-
--dataset 'modelscope/coco_2014_caption:validation#20000' \
7+
--model AIDC-AI/Ovis2.5-2B \
8+
--dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#20000' \
99
--split_dataset_ratio 0.01 \
1010
--train_type lora \
1111
--torch_dtype bfloat16 \
1212
--num_train_epochs 1 \
13-
--per_device_train_batch_size 1 \
14-
--per_device_eval_batch_size 1 \
13+
--per_device_train_batch_size 16 \
14+
--per_device_eval_batch_size 16 \
15+
--attn_impl flash_attn \
16+
--padding_free true \
1517
--learning_rate 1e-4 \
1618
--lora_rank 8 \
1719
--lora_alpha 32 \
1820
--target_modules all-linear \
1921
--freeze_vit true \
20-
--gradient_accumulation_steps 16 \
22+
--gradient_accumulation_steps 1 \
2123
--eval_steps 50 \
2224
--save_steps 50 \
2325
--save_total_limit 2 \

swift/llm/argument/train_args.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -140,25 +140,6 @@ def _check_padding_free(self):
140140
raise ValueError(f'The "{feature}" feature requires a flash attention implementation. '
141141
'Please use one of: "flash_attn", "flash_attention_2", "flash_attention_3".')
142142

143-
if self.model_meta.is_multimodal:
144-
supported_model_type = [
145-
'qwen2_vl',
146-
'qwen2_5_vl',
147-
'qwen2_5_omni',
148-
'qvq',
149-
'mimo_vl',
150-
'internvl',
151-
'internvl_phi3',
152-
'internvl2',
153-
'internvl2_phi3',
154-
'internvl2_5',
155-
'internvl3',
156-
]
157-
if self.model_type not in supported_model_type:
158-
raise ValueError(
159-
f'Packing/padding_free is not supported for model_type `{self.model_type}`. '
160-
f'model_type of multimodal models that support packing/padding_free: {supported_model_type}.')
161-
162143
def __post_init__(self) -> None:
163144
if self.resume_from_checkpoint:
164145
self.resume_from_checkpoint = to_abspath(self.resume_from_checkpoint, True)

swift/llm/template/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class Template(ProcessorMixin):
4949
skip_prompt = True
5050
use_model = False
5151
norm_bbox = 'norm1000'
52+
support_padding_free = False # Whether this template supports padding_free/packing; only checked for multimodal models.
5253

5354
is_encoder_decoder = False
5455

swift/llm/template/template/internvl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class InternvlTemplate(Template):
2020
skip_prompt = False
2121
num_image_token = None
2222
placeholder_tokens = ['<IMG_CONTEXT>']
23+
support_padding_free = True
2324

2425
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
2526
inputs: StdTemplateInputs) -> List[Context]:

swift/llm/template/template/qwen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ class Qwen2VLTemplate(Template):
231231
placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
232232
version = 'v2'
233233
use_model = True
234+
support_padding_free = True
234235

235236
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
236237
inputs: StdTemplateInputs) -> List[Context]:
@@ -741,6 +742,7 @@ class Ovis2_5Template(ThinkingTemplate):
741742
num_frames = 8
742743
use_model = True
743744
skip_prompt = False
745+
support_padding_free = True
744746

745747
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
746748
inputs: StdTemplateInputs) -> List[Context]:

swift/llm/train/sft.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,13 @@ def _prepare_model_tokenizer(self, load_model=True):
5858
self._prepare_generation_config()
5959

6060
def _prepare_template(self) -> None:
61-
template = self.args.get_template(self.processor)
61+
args = self.args
62+
template = args.get_template(self.processor)
6263
template.set_mode('train')
6364
if template.use_model:
6465
template.model = self.model
66+
if args.model_meta.is_multimodal and (args.padding_free or args.packing) and not template.support_padding_free:
67+
raise ValueError(f'Template `{args.template}` does not support padding free or packing.')
6568
self.template = template
6669

6770
def _get_dataset(self):

0 commit comments

Comments
 (0)