
Commit ebe35a3

Revert "fix:Remove use_cache and update ReadMe. (#531)" (#532)

This reverts commit 77dd8af.

1 parent: 77dd8af

File tree: 2 files changed (+8, -15 lines)


README.md (6 additions, 6 deletions)

@@ -916,12 +916,12 @@ For information on supported dataset formats and how to tune a vision-language model

 ? May be supported, but not tested

-Model Name & Size | Model Architecture | LoRA Tuning | Full Finetuning |
--------------------- | ---------------- | --------------- | --------------- |
-Llama 3.2-11B Vision | MllamaForConditionalGeneration | ✅* | ✅* |
-Llava 1.5-7B | LlavaForConditionalGeneration | ✅* | 🚫 |
-Granite 3.1-2B Vision | LlavaNextForConditionalGeneration | ✅* | 🚫 |
-Llava Mistral 1.6-7B | LlavaNextForConditionalGeneration | ✅* | 🚫 |
+Model Name & Size | Model Architecture | Full Finetuning |
+-------------------- | ---------------- | --------------- |
+Llama 3.2-11B Vision | MllamaForConditionalGeneration | ✅* |
+Llava 1.5-7B | LlavaForConditionalGeneration | ✅* |
+Granite 3.1-2B Vision | LlavaNextForConditionalGeneration | ✅* |
+Llava Mistral 1.6-7B | LlavaNextForConditionalGeneration | ✅* |

 (*) - Supported with `fms-hf-tuning` v2.8.0 or later.

tuning/sft_trainer.py (2 additions, 9 deletions)

@@ -237,16 +237,9 @@ def train(
         attn_implementation="flash_attention_2"
         if model_args.use_flash_attn
         else None,
+        # avoid warning that use_cache is incompatible with gradient checkpointing
+        use_cache=(not train_args.gradient_checkpointing),
     )
-    try:
-        if "use_cache" in model.language_model.config:
-            # avoid warning that use_cache is incompatible with gradient checkpointing
-            model.language_model.config.use_cache = (
-                not train_args.gradient_checkpointing
-            )
-    except AttributeError as e:
-        # When the model doesn't have the use_cache attribute
-        logger.warning("Couldn't update use_cache for vision model: %s", e)

     processor = AutoProcessor.from_pretrained(model_args.model_name_or_path)
     tokenizer = processor.tokenizer
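The revert restores the pattern of setting `use_cache` as a load-time keyword argument rather than mutating `model.language_model.config` after the model is built. A minimal sketch of the two patterns follows; `Config` and `Model` here are hypothetical stand-ins for illustration, not the real `transformers` classes:

```python
class Config:
    """Stand-in for a model config carrying a use_cache flag."""

    def __init__(self, use_cache=True):
        self.use_cache = use_cache


class Model:
    """Stand-in for a pretrained model; config overrides land on the config."""

    def __init__(self, **config_overrides):
        self.config = Config(**config_overrides)


gradient_checkpointing = True

# Pattern restored by the revert: disable the KV cache at load time,
# since caching conflicts with gradient checkpointing.
model = Model(use_cache=(not gradient_checkpointing))
print(model.config.use_cache)  # False

# Pattern removed by the revert: mutate the config after loading,
# guarded in case this model type lacks the attribute.
try:
    model.config.use_cache = not gradient_checkpointing
except AttributeError:
    pass
```

The load-time form is shorter and avoids the guarded post-hoc mutation, at the cost of assuming the loader forwards the flag to the relevant sub-config.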
