Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions examples/models/ovis2/train.sh
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
# 28GiB
# 17GiB

pip install "transformers==4.51.*"

CUDA_VISIBLE_DEVICES=0 \
swift sft \
--model AIDC-AI/Ovis2-8B \
--dataset 'modelscope/coco_2014_caption:validation#20000' \
--model AIDC-AI/Ovis2.5-2B \
--dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#20000' \
--split_dataset_ratio 0.01 \
--train_type lora \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--attn_impl flash_attn \
--padding_free true \
--learning_rate 1e-4 \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--freeze_vit true \
--gradient_accumulation_steps 16 \
--gradient_accumulation_steps 1 \
--eval_steps 50 \
--save_steps 50 \
--save_total_limit 2 \
Expand Down
19 changes: 0 additions & 19 deletions swift/llm/argument/train_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,25 +140,6 @@ def _check_padding_free(self):
raise ValueError(f'The "{feature}" feature requires a flash attention implementation. '
'Please use one of: "flash_attn", "flash_attention_2", "flash_attention_3".')

if self.model_meta.is_multimodal:
supported_model_type = [
'qwen2_vl',
'qwen2_5_vl',
'qwen2_5_omni',
'qvq',
'mimo_vl',
'internvl',
'internvl_phi3',
'internvl2',
'internvl2_phi3',
'internvl2_5',
'internvl3',
]
if self.model_type not in supported_model_type:
raise ValueError(
f'Packing/padding_free is not supported for model_type `{self.model_type}`. '
f'model_type of multimodal models that support packing/padding_free: {supported_model_type}.')

def __post_init__(self) -> None:
if self.resume_from_checkpoint:
self.resume_from_checkpoint = to_abspath(self.resume_from_checkpoint, True)
Expand Down
1 change: 1 addition & 0 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Template(ProcessorMixin):
skip_prompt = True
use_model = False
norm_bbox = 'norm1000'
support_padding_free = False # It only takes effect for multimodal models.

is_encoder_decoder = False

Expand Down
1 change: 1 addition & 0 deletions swift/llm/template/template/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class InternvlTemplate(Template):
skip_prompt = False
num_image_token = None
placeholder_tokens = ['<IMG_CONTEXT>']
support_padding_free = True

def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
inputs: StdTemplateInputs) -> List[Context]:
Expand Down
2 changes: 2 additions & 0 deletions swift/llm/template/template/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ class Qwen2VLTemplate(Template):
placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
version = 'v2'
use_model = True
support_padding_free = True

def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
inputs: StdTemplateInputs) -> List[Context]:
Expand Down Expand Up @@ -737,6 +738,7 @@ class Ovis2_5Template(ThinkingTemplate):
num_frames = 8
use_model = True
skip_prompt = False
support_padding_free = True

def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
inputs: StdTemplateInputs) -> List[Context]:
Expand Down
5 changes: 4 additions & 1 deletion swift/llm/train/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,13 @@ def _prepare_model_tokenizer(self, load_model=True):
self._prepare_generation_config()

def _prepare_template(self) -> None:
template = self.args.get_template(self.processor)
args = self.args
template = args.get_template(self.processor)
template.set_mode('train')
if template.use_model:
template.model = self.model
if args.model_meta.is_multimodal and (args.padding_free or args.packing) and not template.support_padding_free:
raise ValueError(f'Template `{args.template}` does not support padding free or packing.')
self.template = template

def _get_dataset(self):
Expand Down