
Commit fe04c01

danielhanchen, everythingisc00l, SethHWeidman, NinoRisteski, and Erland366 authored Mar 13, 2025
Gemma 3 bug fixes (#2005)
* Update rl.py * Update rl_replacements.py * Update rl_replacements.py * llama-quantize on WINDOWS WSL error fix - edit save.py (gguf saving breaks) (#1649) * edit save.py to fix gguf saving breaks. * add check for .exe or not exe file extension for linux and windows * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * unsloth_num_chunks * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py (#1754) Fix typo in comment: know -> now. This was printed when running the Llama3.1_(8B)-GRPO.ipynb example notebook, so I'd expect others to run into it as well. * Optional logits * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * fix an import error (#1767) * fix an import error * Delete .gitignore * Update loader.py * Update save.py --------- Co-authored-by: Daniel Han <[email protected]> * SamplingParams * Convert mask to float (#1762) * [Windows Support] Add latest `xformers` wheels to pyproject.toml (#1753) * Add latest xformers * Add a couple of lines to docs * vLLMSamplingParams * Update __init__.py * default num_chunks == -1 * Versioning * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update _utils.py * Update rl_replacements.py * Update rl_replacements.py * Update pyproject.toml * Update pyproject.toml * Export Model to ollama.com (#1648) * Ollama Export Model to ollama.com Signed-off-by: Jyotin Goel <[email protected]> * Check for model_name Signed-off-by: Jyotin Goel <[email protected]> * subprocess use instead of requests | added check for ollama server Signed-off-by: Jyotin Goel <[email protected]> * create_ollama_model Signed-off-by: Jyotin Goel <[email protected]> * create_ollama_model | fix Signed-off-by: Jyotin Goel <[email protected]> * Push to Ollama Signed-off-by: Jyotin Goel <[email protected]> --------- Signed-off-by: Jyotin Goel <[email protected]> * Update cross_entropy_loss.py * torch_cuda_device * Update utils.py * Update utils.py * Update utils.py * device * device * Update loader.py * Update llama.py * Update README.md * Update llama.py * Update llama.py * Update _utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * __version__ * Update rl.py * Bug fixes * Bug fixes * Update llama.py * Update _utils.py * _wrap_fast_inference * Update llama.py * Update llama.py * Update 
llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update _utils.py * SFT dataset prepare * Update pyproject.toml * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update llama.py * Update llama.py * Update utils.py * bug fix * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update __init__.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update rl.py * Update rl.py * Update rl.py * Update _utils.py * Update __init__.py * Update _utils.py * Version * versioning * Update _utils.py * Update llama.py * Update llama.py * Bug fixes * FastModel * __doc__ * Update vision.py * Update loader.py * Update loader.py * Update loader.py * version * move use_modelscope to _utils (#1938) * move use_modelscope to _utils * Update _utils.py * Update loader.py --------- Co-authored-by: Daniel Han <[email protected]> * Don't use revision when loading model_config and is_peft=True (#1949) * More syntax warnings (#1944) * move use_modelscope to _utils * fix * Update _utils.py * Update loader.py --------- Co-authored-by: Daniel Han <[email protected]> * Update loader.py * Full finetuning and other fixes * UNSLOTH_ENABLE_FULL_FINETUNING * Update loader.py * Update loader.py * Update loader.py * Update vision.py * Update vision.py * full finetuning * Update loader.py * Update loader.py * Update loader.py * Update _utils.py * max_seq_length * Update rl.py * Update rl.py * Update rl.py * Update pyproject.toml * AutoModelForImageTextToText * Update mapper.py * Update pyproject.toml * Update _utils.py * Update _utils.py * Update _utils.py * Batch samples * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update _utils.py * Update loader.py * Update vision.py * Update loader.py * Update vision.py * Update vision.py * Update vision.py * Update mapper.py * Update vision.py * Temporary patches * Update loader.py * model names * Gemma 3 chat template * Bug fixes * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update llama.py * Update llama.py * Update rl.py * Update chat_templates.py * Update chat_templates.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * Update vision.py * Update vision.py * Revert * Update _utils.py * forced precision * Autocast * Update vision.py * Update vision.py * Update rl.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py --------- Signed-off-by: Jyotin Goel <[email protected]> Co-authored-by: Gennadii Manzhos <[email protected]> Co-authored-by: Seth Weidman <[email protected]> Co-authored-by: Nino Risteski <[email protected]> Co-authored-by: Edd <[email protected]> Co-authored-by: Ben <[email protected]> Co-authored-by: Jyotin Goel <[email protected]> Co-authored-by: Kareem <[email protected]> Co-authored-by: Wilson Wu <[email protected]>
1 parent 71039cb commit fe04c01

File tree: 7 files changed, +224 -115 lines changed

 

unsloth/chat_templates.py (+92 -81)
@@ -20,6 +20,7 @@

     "to_sharegpt",
     "standardize_sharegpt",
+    "standardize_data_formats",
     "apply_chat_template",
     "train_on_responses_only",

@@ -37,7 +38,9 @@
 import re
 from unsloth_zoo.dataset_utils import (
     train_on_responses_only,
+    standardize_data_formats,
 )
+standardize_sharegpt = standardize_data_formats
 CHAT_TEMPLATES = {}
 DEFAULT_SYSTEM_MESSAGE = {}

@@ -934,6 +937,84 @@
 pass


+# =========================================== Gemma-3
+# Obtained via
+# print(tokenizer.chat_template.replace("}\n", "####").replace("\n", "\\n").replace("####", "}\n"))
+gemma3_template = \
+"""{{ bos_token }}
+{%- if messages[0]['role'] == 'system' -%}
+    {%- if messages[0]['content'] is string -%}
+        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
+    {%- else -%}
+        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
+    {%- endif -%}
+    {%- set loop_messages = messages[1:] -%}
+{%- else -%}
+    {%- set first_user_prefix = "" -%}
+    {%- set loop_messages = messages -%}
+{%- endif -%}
+{%- for message in loop_messages -%}
+    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
+        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
+    {%- endif -%}
+    {%- if (message['role'] == 'assistant') -%}
+        {%- set role = "model" -%}
+    {%- else -%}
+        {%- set role = message['role'] -%}
+    {%- endif -%}
+    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
+    {%- if message['content'] is string -%}
+        {{ message['content'] | trim }}
+    {%- elif message['content'] is iterable -%}
+        {%- for item in message['content'] -%}
+            {%- if item['type'] == 'image' -%}
+                {{ '<start_of_image>' }}
+            {%- elif item['type'] == 'text' -%}
+                {{ item['text'] | trim }}
+            {%- endif -%}
+        {%- endfor -%}
+    {%- else -%}
+        {{ raise_exception("Invalid content type") }}
+    {%- endif -%}
+    {{ '<end_of_turn>\n' }}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+    {{ '<start_of_turn>model\n' }}
+{%- endif -%}
+"""
+
+# Ollama from https://ollama.com/library/gemma3/blobs/e0a42594d802
+gemma3_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 }}
+{{- if or (eq .Role "user") (eq .Role "system") }}<start_of_turn>user
+{{ .Content }}<end_of_turn>
+{{ if $last }}<start_of_turn>model
+{{ end }}
+{{- else if eq .Role "assistant" }}<start_of_turn>model
+{{ .Content }}{{ if not $last }}<end_of_turn>
+{{ end }}
+{{- end }}
+{{- end }}"""
+PARAMETER stop "<end_of_turn>"
+PARAMETER stop "<eos>"
+PARAMETER temperature 0.1
+PARAMETER min_p 0.0
+PARAMETER top_k 64
+PARAMETER top_p 0.95
+PARAMETER num_predict 32768
+'''
+
+gemma3_template_eos_token = "<end_of_turn>"
+CHAT_TEMPLATES["gemma-3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)
+DEFAULT_SYSTEM_MESSAGE["gemma-3"] = None # No system message in Gemma-3
+
+CHAT_TEMPLATES["gemma3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)
+DEFAULT_SYSTEM_MESSAGE["gemma3"] = None # No system message in Gemma-3
+pass
+
 def _change_system_message(template: str, type_chat_template: str, system_message: str = None):
     system_message_pattern = r"\{system_message\}"

@@ -1033,11 +1114,12 @@ def get_chat_template(

     # Check fast tokenizer
     if not is_fast_tokenizer:
-        print(
-            "Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
-            "Please log a Github issue if you want this as a new feature!\n"\
-            "Your chat template will still work, but it won't add or edit tokens."
-        )
+        pass
+        # print(
+        #     "Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
+        #     "Please log a Github issue if you want this as a new feature!\n"\
+        #     "Your chat template will still work, but it won't add or edit tokens."
+        # )

     elif token_mapping is not None:
         # token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
@@ -1396,82 +1478,6 @@ def __convert_to_sharegpt__(examples):
 pass


-def standardize_sharegpt(
-    dataset,
-    aliases_for_system = ["system",],
-    aliases_for_user = ["user", "human", "input",],
-    aliases_for_assistant = ["gpt", "assistant", "output",],
-):
-    """
-    Standardizes ShareGPT and other formats to user/assistant Hugging Face format.
-
-    Get aliases for the system, user and assistant roles.
-    These shall map to "system", "user" and "assistant" respectively.
-
-    aliases_for_system = ["system",],
-    aliases_for_user = ["user", "human", "input",],
-    aliases_for_assistant = ["gpt", "assistant", "output",],
-    """
-    import collections
-    import itertools
-
-    convos = dataset[:10]["conversations"]
-    uniques = collections.defaultdict(list)
-    for convo in convos:
-        for message in convo:
-            for key, value in message.items():
-                uniques[key].append(value)
-    pass
-
-    # Must be only 2 entries
-    assert(len(uniques.keys()) == 2)
-
-    keys = list(uniques.keys())
-    length_first = len(set(uniques[keys[0]]))
-    length_second = len(set(uniques[keys[1]]))
-
-    if length_first < length_second:
-        # Role is assigned to the first element
-        role_key = keys[0]
-        content_key = keys[1]
-    else:
-        role_key = keys[1]
-        content_key = keys[0]
-    pass
-
-    # Check roles are in aliases
-    all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant)
-    roles = set(uniques[role_key])
-    leftover_aliases = (all_aliases | roles) - all_aliases
-    if len(leftover_aliases) != 0:
-        raise TypeError(
-            f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases."
-        )
-    pass
-
-    # Mapping for aliases
-    aliases_mapping = {}
-    for x in aliases_for_system: aliases_mapping[x] = "system"
-    for x in aliases_for_user: aliases_mapping[x] = "user"
-    for x in aliases_for_assistant: aliases_mapping[x] = "assistant"
-
-    def _standardize_dataset(examples):
-        convos = examples["conversations"]
-        all_convos = []
-        for convo in convos:
-            new_convo = [
-                { "role" : aliases_mapping[message[role_key]], "content" : message[content_key], }
-                for message in convo
-            ]
-            all_convos.append(new_convo)
-        pass
-        return { "conversations" : all_convos, }
-    pass
-
-    return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format")
-pass
-
-
 def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
     added_tokens_decoder = tokenizer.added_tokens_decoder.values()
     added_tokens_decoder = [str(x) for x in added_tokens_decoder]
@@ -1934,6 +1940,11 @@ def formatting_prompts_func(examples):
     tokenizer._ollama_modelfile = modelfile
     tokenizer._unsloth_input_part = input_part
     tokenizer._unsloth_output_part = output_part
+    if hasattr(tokenizer, "tokenizer"):
+        tokenizer.tokenizer.chat_template = jinja_template
+        tokenizer.tokenizer._ollama_modelfile = modelfile
+        tokenizer.tokenizer._unsloth_input_part = input_part
+        tokenizer.tokenizer._unsloth_output_part = output_part

     return dataset.map(formatting_prompts_func, batched = True,)
 pass
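For orientation (not part of the diff): the template registered above can be exercised directly once a Gemma 3 tokenizer is loaded. A minimal sketch, assuming the existing get_chat_template API and an illustrative checkpoint name:

    from transformers import AutoTokenizer
    from unsloth.chat_templates import get_chat_template, standardize_data_formats

    # Illustrative checkpoint; any Gemma 3 tokenizer should behave the same way.
    tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")

    # Pick up the newly registered template ("gemma3" is an equivalent key).
    tokenizer = get_chat_template(tokenizer, chat_template = "gemma-3")

    messages = [{"role": "user", "content": "Why is the sky blue?"}]
    text = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
    # Expected shape: <bos><start_of_turn>user\n...<end_of_turn>\n<start_of_turn>model\n

    # dataset = standardize_data_formats(dataset)  # replaces the removed standardize_sharegpt helper,
    #                                              # which now survives only as an alias (see import block above).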

unsloth/models/_utils.py (+14 -1)
@@ -71,6 +71,7 @@
 from platform import system as platform_system
 platform_system = platform_system()
 import numpy as np
+import contextlib
 import warnings, subprocess, re, inspect, psutil, os, math
 from unsloth_zoo.utils import Version

@@ -113,6 +114,11 @@
 from unsloth_zoo.training_utils import (
     prepare_model_for_training,
 )
+from unsloth_zoo.temporary_patches import (
+    TEMPORARY_PATCHES,
+)
+for temporary_patch in TEMPORARY_PATCHES:
+    temporary_patch()

 # =============================================
 # Disable some warnings which can get annoying
@@ -981,7 +987,14 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
             "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
         )
     pass
-    return self._old_compute_loss(model, inputs, *args, **kwargs)
+
+    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
+        autocaster = contextlib.nullcontext()
+    else:
+        autocaster = torch.autocast(device_type = "cuda", dtype = torch.float32)
+    with autocaster:
+        outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
+    return outputs
 pass

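For orientation (not part of the diff): this change means the trainer's loss computation can be routed through a float32 autocast purely via an environment variable. A minimal sketch of how that flag is consumed; the helper name here is hypothetical:

    import contextlib
    import os

    import torch

    def loss_autocast_context():
        # Hypothetical helper mirroring the patched _unsloth_pre_compute_loss above.
        # UNSLOTH_FORCE_FLOAT32 is set to "1" automatically for Gemma 3 when the GPU
        # only supports float16 (see unsloth/models/vision.py below).
        if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
            return contextlib.nullcontext()
        return torch.autocast(device_type = "cuda", dtype = torch.float32)

    # Usage sketch:
    # with loss_autocast_context():
    #     loss = compute_loss(model, inputs)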

unsloth/models/llama.py (+33)
@@ -38,6 +38,7 @@
 from ..tokenizer_utils import *
 if HAS_FLASH_ATTENTION:
     from flash_attn import flash_attn_func
+from .vision import FastBaseModel

 # Final patching code
 from transformers.models.llama.modeling_llama import (
@@ -1648,6 +1649,7 @@ def from_pretrained(
         disable_log_stats = False,
         **kwargs,
     ):
+        os.environ["UNSLOTH_USE_NEW_MODEL"] = "0"
         if trust_remote_code:
             if fast_inference:
                 raise NotImplementedError("Unsloth: Fast inference does not support `trust_remote_code` yet.")
@@ -2016,6 +2018,31 @@ def get_peft_model(
         temporary_location = "_unsloth_temporary_saved_buffers",
         **kwargs,
     ):
+        if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1":
+            return FastBaseModel.get_peft_model(
+                model = model,
+                r = r,
+                target_modules = target_modules,
+                lora_alpha = lora_alpha,
+                lora_dropout = lora_dropout,
+                bias = bias,
+                finetune_vision_layers = False,
+                finetune_language_layers = True,
+                finetune_attention_modules = True,
+                finetune_mlp_modules = True,
+                layers_to_transform = layers_to_transform,
+                layers_pattern = layers_pattern,
+                use_gradient_checkpointing = use_gradient_checkpointing,
+                random_state = random_state,
+                max_seq_length = max_seq_length,
+                use_rslora = use_rslora,
+                modules_to_save = modules_to_save,
+                init_lora_weights = init_lora_weights,
+                loftq_config = loftq_config,
+                temporary_location = temporary_location,
+                **kwargs,
+            )
+        pass
         if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1":
             print("Unsloth: Full finetuning is enabled, so .get_peft_model has no effect")
             return model
@@ -2435,6 +2462,12 @@ def patch_peft_model(
         model,
         use_gradient_checkpointing = True,
     ):
+        if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1":
+            return FastBaseModel.patch_peft_model(
+                model = model,
+                use_gradient_checkpointing = use_gradient_checkpointing,
+            )
+        pass
         if not isinstance(model, PeftModelForCausalLM):
             raise TypeError(
                 "Unsloth: Your model needs to call `.get_peft_model` first!"

unsloth/models/loader.py (+6 -4)
@@ -70,7 +70,7 @@ class FastLanguageModel(FastLlamaModel):
     @staticmethod
     def from_pretrained(
         model_name = "unsloth/Llama-3.2-1B-Instruct",
-        max_seq_length = None,
+        max_seq_length = 2048,
         dtype = None,
         load_in_4bit = True,
         load_in_8bit = False,
@@ -96,7 +96,7 @@ def from_pretrained(
         if load_in_8bit or full_finetuning:
             return FastModel.from_pretrained(
                 model_name = model_name,
-                max_seq_length = max_seq_length, # [TODO] No effect
+                max_seq_length = max_seq_length,
                 dtype = dtype,
                 load_in_4bit = load_in_4bit,
                 load_in_8bit = load_in_8bit,
@@ -295,7 +295,7 @@ def from_pretrained(
         else:
             return FastModel.from_pretrained(
                 model_name = model_name,
-                max_seq_length = max_seq_length, # [TODO] No effect
+                max_seq_length = max_seq_length,
                 dtype = dtype,
                 load_in_4bit = load_in_4bit,
                 load_in_8bit = load_in_8bit,
@@ -442,7 +442,7 @@ class FastModel(FastBaseModel):
     @staticmethod
     def from_pretrained(
         model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
-        max_seq_length = None, # [TODO] No effect
+        max_seq_length = 2048,
        dtype = None,
         load_in_4bit = True,
         load_in_8bit = False,
@@ -500,6 +500,8 @@ def from_pretrained(
             raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST)
         elif "aya-vision" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
             raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY)
+        elif "gemma-3" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
+            raise RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY)
         pass

         if USE_MODELSCOPE and not os.path.exists(model_name):
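For orientation (not part of the diff): max_seq_length now defaults to 2048 and is actually forwarded, and Gemma 3 names are gated on transformers >= 4.50.0.dev0. A minimal loading sketch with an illustrative checkpoint, assuming FastModel is importable from the top-level unsloth package:

    from unsloth import FastModel

    model, tokenizer = FastModel.from_pretrained(
        model_name = "unsloth/gemma-3-4b-it",  # illustrative; resolved via mapper.py below
        max_seq_length = 2048,                 # forwarded now, no longer "[TODO] No effect"
        load_in_4bit = True,                   # QLoRA path; load_in_8bit / full finetuning also route here
    )
    # On older transformers versions the new guard above raises:
    # RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY)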

unsloth/models/mapper.py (+16 -8)
@@ -638,37 +638,45 @@
         "Qwen/QwQ-32B",
         "unsloth/QwQ-32B-bnb-4bit",
     ),
-    "unsloth/gemma-3-1b-it-bnb-4bit" : (
+    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-1b-it",
         "google/gemma-3-1b-it",
+        "unsloth/gemma-3-1b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-4b-it-bnb-4bit" : (
+    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-4b-it",
         "google/gemma-3-4b-it",
+        "unsloth/gemma-3-4b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-12b-it-bnb-4bit" : (
+    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-12b-it",
         "google/gemma-3-12b-it",
+        "unsloth/gemma-3-12b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-27b-it-bnb-4bit" : (
+    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-27b-it",
         "google/gemma-3-27b-it",
+        "unsloth/gemma-3-27b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-1b-pt-bnb-4bit" : (
+    "unsloth/gemma-3-1b-pt-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-1b-pt",
         "google/gemma-3-1b-pt",
+        "unsloth/gemma-3-1b-pt-bnb-4bit",
     ),
-    "unsloth/gemma-3-4b-pt-bnb-4bit" : (
+    "unsloth/gemma-3-4b-pt-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-4b-pt",
         "google/gemma-3-4b-pt",
+        "unsloth/gemma-3-4b-pt-bnb-4bit",
     ),
-    "unsloth/gemma-3-12b-pt-bnb-4bit" : (
+    "unsloth/gemma-3-12b-pt-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-12b-pt",
         "google/gemma-3-12b-pt",
+        "unsloth/gemma-3-12b-pt-bnb-4bit",
     ),
-    "unsloth/gemma-3-27b-pt-bnb-4bit" : (
+    "unsloth/gemma-3-27b-pt-unsloth-bnb-4bit" : (
         "unsloth/gemma-3-27b-pt",
         "google/gemma-3-27b-pt",
+        "unsloth/gemma-3-27b-pt-bnb-4bit",
     ),
 }

unsloth/models/rl.py (+16 -6)
@@ -236,15 +236,24 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     mixed_precision = \
         "use_bf16 = getattr(args, 'bf16', False)\n"\
         "use_fp16 = getattr(args, 'fp16', False)\n"\
+        "force_float32 = False\n"\
+        "if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':\n"\
+        "    if use_bf16 or use_fp16:\n"\
+        "        print('Unsloth: Switching to float32 training since model cannot work with float16')\n"\
+        "    force_float32 = True\n"\
        "mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')\n"\
         "dtype = getattr(model.config, 'torch_dtype', None)\n"\
         "if dtype is None: dtype = model.get_input_embeddings().dtype\n"\
         "from unsloth_zoo.utils import _get_dtype\n"\
         "dtype = _get_dtype(dtype)\n"\
         "float16 = dtype == torch.float16\n"\
-        "if float16 and use_bf16: raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\
-        "if not float16 and use_fp16: raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\
-        "if (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"\
+        "if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\
+        "if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\
+        "if force_float32:\n"\
+        "    args.fp16 = False\n"\
+        "    args.bf16 = False\n"\
+        "    os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"\
+        "elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"\
         "    args.fp16 = float16\n"\
         "    args.bf16 = not float16\n"\
         "    os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\n"
@@ -287,7 +296,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
         "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"\
         "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\
         "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\
-        "if os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"\
+        "if force_float32:\n"\
+        "    args.bf16_full_eval = False\n"\
+        "    args.fp16_full_eval = False\n"\
+        "elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"\
         "    args.bf16_full_eval = True\n"\
         "    args.fp16_full_eval = False\n"\
         "elif not bf16_full_eval and not fp16_full_eval:\n"\
@@ -343,11 +355,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     if "data_collator" in call_args and "train_dataset" in call_args:
         data_collator_check = \
             "if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\
-            "    print('Unsloth: Changing data collator to `DataCollatorForLanguageModeling` since `labels` not found.')\n"\
             "    data_collator = DataCollatorForLanguageModeling("\
             "tokenizer = processing_class if 'processing_class' in locals() else tokenizer, mlm = False)\n"\
             "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\
-            "    print('Unsloth: Changing data collator to `DataCollatorForSeq2Seq` since `labels` found.')\n"\
             "    data_collator = DataCollatorForSeq2Seq("\
             "tokenizer = processing_class if 'processing_class' in locals() else tokenizer)\n"
         extra_args += data_collator_check

unsloth/models/vision.py (+47 -15)
@@ -25,29 +25,49 @@
 except:
     from transformers import AutoModelForVision2Seq
 pass
-from .llama import *
 from ..kernels import (
     post_patch_loss_function,
 )
 from ._utils import __version__
+from ._utils import *
+from ..save import patch_saving_functions
 from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
+from peft import PeftModelForCausalLM
 from transformers import set_seed as transformers_set_seed
 from unsloth_zoo.peft_utils import (
     get_peft_regex,
     SKIP_QUANTIZATION_MODULES,
     requires_grad_for_gradient_checkpointing,
 )
+from transformers.models.llama.modeling_llama import logger
+from transformers import __version__ as transformers_version
 from triton import __version__ as triton_version
 from unsloth_zoo.utils import _get_dtype
 from unsloth_zoo.patching_utils import patch_model_and_tokenizer
 from unsloth_zoo.training_utils import prepare_model_for_training
 import types
 import functools
+import os
+import gc
+import math
+import functools
+from typing import Optional, Tuple, List, Union
+import re, inspect, sys
+import types
+try:
+    from huggingface_hub.utils import get_token
+except:
+    # Old HF Hub versions <= 0.0.25
+    from huggingface_hub.utils._token import get_token
+pass

 __all__ = [
     "FastBaseModel",
 ]

+global FORCE_FLOAT32
+FORCE_FLOAT32 = ["gemma3"]
+

 def unsloth_base_fast_generate(
     self,
@@ -86,6 +106,7 @@ def unsloth_base_fast_generate(
     except: pass

     # Mixed precision autocast
+    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32
     with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype):
         output = self._old_generate(*args, **kwargs)
     pass
@@ -100,7 +121,7 @@ class FastBaseModel:
     @staticmethod
     def from_pretrained(
         model_name = "unsloth/Llama-3.2-1B-Instruct",
-        max_seq_length = None,
+        max_seq_length = 2048,
         dtype = None,
         load_in_4bit = True,
         load_in_8bit = False,
@@ -114,6 +135,7 @@ def from_pretrained(
         use_gradient_checkpointing = "unsloth",
         **kwargs,
     ):
+        os.environ["UNSLOTH_USE_NEW_MODEL"] = "1"
         if trust_remote_code:
             print(
                 "Unsloth: WARNING `trust_remote_code` is True.\n"\
@@ -129,8 +151,12 @@ def from_pretrained(
         try: vllm_version = f" vLLM: {importlib_version('vllm')}."
         except: vllm_version = ""

+        model_type_arch = model_types[0]
+        if model_type_arch == "siglip" and len(model_types) != 1:
+            model_type_arch = model_types[1]
+
         statistics = \
-        f"==((====))== Unsloth {__version__}: Fast {model_types[0].title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\
+        f"==((====))== Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\
         f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
         f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\
         f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
@@ -156,6 +182,17 @@ def from_pretrained(

         assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)

+        global FORCE_FLOAT32
+        os.environ["UNSLOTH_FORCE_FLOAT32"] = "0"
+        bnb_compute_dtype = dtype
+        for disable_name in FORCE_FLOAT32:
+            if disable_name.lower() == model_type_arch.lower() and dtype == torch.float16:
+                print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.")
+                os.environ["UNSLOTH_FORCE_FLOAT32"] = "1"
+                bnb_compute_dtype = torch.float32
+                break
+        pass
+
         bnb_config = None
         if full_finetuning and (load_in_4bit or load_in_8bit):
             print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.")
@@ -170,13 +207,13 @@ def from_pretrained(
                 load_in_4bit = True,
                 bnb_4bit_use_double_quant = True,
                 bnb_4bit_quant_type = "nf4",
-                bnb_4bit_compute_dtype = dtype,
-                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES,
+                bnb_4bit_compute_dtype = bnb_compute_dtype,
+                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
             )
         elif load_in_8bit:
             bnb_config = BitsAndBytesConfig(
                 load_in_8bit = True,
-                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES,
+                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
             )
         elif not load_in_4bit and not load_in_8bit and not full_finetuning:
             print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.")
@@ -185,8 +222,8 @@ def from_pretrained(
                 load_in_4bit = True,
                 bnb_4bit_use_double_quant = True,
                 bnb_4bit_quant_type = "nf4",
-                bnb_4bit_compute_dtype = dtype,
-                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES,
+                bnb_4bit_compute_dtype = bnb_compute_dtype,
+                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
             )
         pass

@@ -212,7 +249,7 @@ def from_pretrained(
             # quantization_config = bnb_config,
             token = token,
             trust_remote_code = trust_remote_code,
-            # attn_implementation = "sdpa", [TODO] Pixtral for eg fails
+            attn_implementation = "sdpa", #[TODO] Pixtral for eg fails
             **kwargs,
         )
         # Return old flag
@@ -408,12 +445,7 @@ def post_patch_model(

         from transformers.trainer import Trainer
         if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
-            raise RuntimeError(
-                'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\
-                'enabling it will require much more work, so we have to prioritize. Please understand!\n'\
-                'We do have a separate beta version, which you can contact us about!\n'\
-                'Thank you for your understanding and we appreciate it immensely!'
-            )
+            raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop')
         pass
         patch_saving_functions(model, vision = True)

0 commit comments