
Commit 029461a

Authored by danielhanchen, NinoRisteski, Erland366, versipellis, and gjyotin305
Gemma 3, bug fixes (#2014)
* Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * fix an import error (#1767) * fix an import error * Delete .gitignore * Update loader.py * Update save.py --------- Co-authored-by: Daniel Han <[email protected]> * SamplingParams * Convert mask to float (#1762) * [Windows Support] Add latest `xformers` wheels to pyproject.toml (#1753) * Add latest xformers * Add a couple of lines to docs * vLLMSamplingParams * Update __init__.py * default num_chunks == -1 * Versioning * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update _utils.py * Update rl_replacements.py * Update rl_replacements.py * Update pyproject.toml * Update pyproject.toml * Export Model to ollama.com (#1648) * Ollama Export Model to ollama.com Signed-off-by: Jyotin Goel <[email protected]> * Check for model_name Signed-off-by: Jyotin Goel <[email protected]> * subprocess use instead of requests | added check for ollama server Signed-off-by: Jyotin Goel <[email protected]> * create_ollama_model Signed-off-by: Jyotin Goel <[email protected]> * create_ollama_model | fix Signed-off-by: Jyotin Goel <[email protected]> * Push to Ollama Signed-off-by: Jyotin Goel <[email protected]> --------- Signed-off-by: Jyotin Goel <[email protected]> * Update cross_entropy_loss.py * torch_cuda_device * Update utils.py * Update utils.py * Update utils.py * device * device * Update loader.py * Update llama.py * Update README.md * Update llama.py * Update llama.py * Update _utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * __version__ * Update rl.py * Bug fixes * Bug fixes * Update llama.py * Update _utils.py * _wrap_fast_inference * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update _utils.py * SFT dataset prepare * Update pyproject.toml * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update llama.py * Update llama.py * Update utils.py * bug fix * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update __init__.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update rl.py * Update rl.py * Update rl.py * Update _utils.py * Update __init__.py * Update _utils.py * Version * versioning * Update _utils.py * Update llama.py * Update llama.py * Bug fixes * FastModel * __doc__ * Update vision.py * Update loader.py * Update loader.py * Update loader.py * version * move use_modelscope to _utils (#1938) * move use_modelscope to _utils * Update _utils.py * Update loader.py --------- Co-authored-by: Daniel Han <[email protected]> * Don't use revision when loading model_config and is_peft=True (#1949) * More syntax warnings (#1944) * move use_modelscope to _utils * fix * Update _utils.py * Update loader.py --------- Co-authored-by: Daniel Han <[email protected]> * Update loader.py * Full finetuning and other fixes * UNSLOTH_ENABLE_FULL_FINETUNING * Update loader.py * Update loader.py * Update loader.py * Update vision.py * Update vision.py * full finetuning * Update loader.py * Update loader.py * Update 
loader.py * Update _utils.py * max_seq_length * Update rl.py * Update rl.py * Update rl.py * Update pyproject.toml * AutoModelForImageTextToText * Update mapper.py * Update pyproject.toml * Update _utils.py * Update _utils.py * Update _utils.py * Batch samples * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update _utils.py * Update loader.py * Update vision.py * Update loader.py * Update vision.py * Update vision.py * Update vision.py * Update mapper.py * Update vision.py * Temporary patches * Update loader.py * model names * Gemma 3 chat template * Bug fixes * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update llama.py * Update llama.py * Update rl.py * Update chat_templates.py * Update chat_templates.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * Update vision.py * Update vision.py * Revert * Update _utils.py * forced precision * Autocast * Update vision.py * Update vision.py * Update rl.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update rl.py * vLLM fixes * constexpr * Update vision.py * Update vision.py * Update vision.py * Update rl.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update save.py * New models * Triton windows update (#1976) * Update pyproject.toml * Update README.md * Update RMS LayerNorm implementation, and list compr. change in chat templates (#1974) * Update RMS LayerNorm implementation with optimizations and testing suite * perf: optimize list comprehension in get_ollama_eos_tokens * Update Zoo * Update llama.py * Update llama.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update rl_replacements.py * Update vision.py * grpo fix * Update rl_replacements.py * Update vision.py * Update rl_replacements.py * Update vision.py * Update mapper.py * Update vision.py * Update vision.py * Update loader.py --------- Signed-off-by: Jyotin Goel <[email protected]> Co-authored-by: Nino Risteski <[email protected]> Co-authored-by: Edd <[email protected]> Co-authored-by: Ben <[email protected]> Co-authored-by: Jyotin Goel <[email protected]> Co-authored-by: Kareem <[email protected]> Co-authored-by: Wilson Wu <[email protected]> Co-authored-by: Akshay Behl <[email protected]>
1 parent fe04c01 commit 029461a

15 files changed: +243 −65 lines

README.md  (+1 −1)

@@ -115,7 +115,7 @@ See [here](https://github.com/unslothai/unsloth/edit/main/README.md#advanced-pip
 7. **Install Unsloth:**

 ```python
-pip install "unsloth[windows] @ git+https://github.com/unslothai/unsloth.git"
+pip install unsloth
 ```

 #### Notes

pyproject.toml  (+2 −5)

@@ -33,14 +33,11 @@ exclude = ["images*"]

 [project.optional-dependencies]
 triton = [
-    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
-    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
-    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
-    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'"
+    "triton-windows ; platform_system == 'Windows'",
 ]

 huggingface = [
-    "unsloth_zoo>=2025.3.9",
+    "unsloth_zoo>=2025.3.11",
     "packaging",
     "tyro",
     "transformers>=4.46.1,!=4.47.0",

unsloth/__init__.py  (+1 −1)

@@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
 # Check for unsloth_zoo
 try:
     unsloth_zoo_version = importlib_version("unsloth_zoo")
-    if Version(unsloth_zoo_version) < Version("2025.3.9"):
+    if Version(unsloth_zoo_version) < Version("2025.3.11"):
         print(
             "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\
             "To disable this, set os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'"

unsloth/chat_templates.py  (+1 −4)

@@ -1512,10 +1512,7 @@ def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):

     # Remove duplicates
     splitted = joined_text.split("\x01\x00")
-    final_eos_tokens = []
-    for old, new in zip(added_tokens_decoder, splitted):
-        if old == new: final_eos_tokens.append(old)
-    pass
+    final_eos_tokens = [old for old, new in zip(added_tokens_decoder, splitted) if old == new]
     final_eos_tokens += extra_eos_tokens
     final_eos_tokens += repeatted_tokens
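
The loop-to-comprehension change is behaviour-preserving: a token is kept only when its round-tripped text is unchanged. A tiny standalone illustration with made-up token lists:

```python
# Hypothetical data: tokens survive only if they decode back to themselves.
added_tokens_decoder = ["<eos>", "<pad>", "<unk>"]
splitted             = ["<eos>", "<PAD>", "<unk>"]

final_eos_tokens = [old for old, new in zip(added_tokens_decoder, splitted) if old == new]
print(final_eos_tokens)  # ['<eos>', '<unk>']
```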

unsloth/kernels/cross_entropy_loss.py  (+16 −16)

@@ -37,12 +37,12 @@ def _cross_entropy_forward(
     loss_ptr,
     logsumexp_ptr,
     labels_ptr,
-    VOCAB_SIZE,
+    VOCAB_SIZE       : tl.constexpr,
     BLOCK_SIZE       : tl.constexpr,
-    DO_SOFTCAPPING,
-    SOFTCAP,
-    DO_LOGIT_SCALING,
-    LOGIT_SCALE,
+    DO_SOFTCAPPING   : tl.constexpr,
+    SOFTCAP          : tl.constexpr,
+    DO_LOGIT_SCALING : tl.constexpr,
+    LOGIT_SCALE      : tl.constexpr,
 ):
     """
         Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
@@ -111,13 +111,13 @@ def _chunked_cross_entropy_forward(
     loss_ptr,
     logsumexp_ptr,
     labels_ptr,
-    VOCAB_SIZE,
-    N_CHUNKS,
+    VOCAB_SIZE       : tl.constexpr,
+    N_CHUNKS         : tl.constexpr,
     BLOCK_SIZE       : tl.constexpr,
-    DO_SOFTCAPPING,
-    SOFTCAP,
-    DO_LOGIT_SCALING,
-    LOGIT_SCALE,
+    DO_SOFTCAPPING   : tl.constexpr,
+    SOFTCAP          : tl.constexpr,
+    DO_LOGIT_SCALING : tl.constexpr,
+    LOGIT_SCALE      : tl.constexpr,
 ):
     """
         256K vocab divided in 4 chunks
@@ -196,12 +196,12 @@ def _cross_entropy_backward(
     dloss_row_stride,
     logsumexp_ptr,
     labels_ptr,
-    VOCAB_SIZE,
+    VOCAB_SIZE       : tl.constexpr,
     BLOCK_SIZE       : tl.constexpr,
-    DO_SOFTCAPPING,
-    SOFTCAP,
-    DO_LOGIT_SCALING,
-    LOGIT_SCALE,
+    DO_SOFTCAPPING   : tl.constexpr,
+    SOFTCAP          : tl.constexpr,
+    DO_LOGIT_SCALING : tl.constexpr,
+    LOGIT_SCALE      : tl.constexpr,
 ):
     """
         CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
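
For context, arguments declared `tl.constexpr` are treated by Triton as compile-time constants, so flags such as `DO_SOFTCAPPING` select a specialized kernel variant instead of being branched on at run time. A minimal, self-contained sketch of the same pattern (this is not Unsloth's kernel; names and shapes are illustrative):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(
    x_ptr, y_ptr,
    n_elements,
    SCALE      : tl.constexpr,  # baked in at compile time
    DO_SCALING : tl.constexpr,  # dead branch eliminated during specialization
    BLOCK_SIZE : tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask = mask)
    if DO_SCALING:
        x = x * SCALE
    tl.store(y_ptr + offsets, x, mask = mask)

# Requires a CUDA device.
x = torch.randn(1024, device = "cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
scale_kernel[grid](x, y, x.numel(), SCALE = 2.0, DO_SCALING = True, BLOCK_SIZE = 256)
```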

unsloth/kernels/layernorm.py  (+4 −2)

@@ -30,7 +30,8 @@ def layernorm_forward(
     b,
     r,
     mu,
-    n_cols, eps,
+    n_cols : tl.constexpr,
+    eps    : tl.constexpr,
     BLOCK_SIZE : tl.constexpr
 ):
     row_idx = tl.program_id(0)
@@ -68,7 +69,8 @@ def layernorm_backward(
     b,
     r,
     mu,
-    n_cols, eps,
+    n_cols : tl.constexpr,
+    eps    : tl.constexpr,
     BLOCK_SIZE : tl.constexpr
 ):
     # Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md

unsloth/kernels/rms_layernorm.py  (+10 −8)

@@ -22,9 +22,10 @@ def _rms_layernorm_forward(
     Y, Y_row_stride,
     X, X_row_stride,
     W, W_row_stride,
-    r, r_row_stride,
-    n_cols, eps,
-    BLOCK_SIZE : tl.constexpr
+    r, r_row_stride : tl.constexpr,
+    n_cols : tl.constexpr,
+    eps    : tl.constexpr,
+    BLOCK_SIZE : tl.constexpr,
 ):
     """
         Fast RMS Layernorm kernel
@@ -57,9 +58,10 @@ def _rms_layernorm_backward(
     dX, dX_row_stride,
     X, X_row_stride,
     W, W_row_stride,
-    r, r_row_stride,
+    r, r_row_stride : tl.constexpr,
     # dW, dW_row_stride,
-    n_cols, eps,
+    n_cols : tl.constexpr,
+    eps    : tl.constexpr,
     GEMMA : tl.constexpr,
     BLOCK_SIZE : tl.constexpr,
 ):
@@ -107,8 +109,9 @@ def _gemma_rms_layernorm_forward(
     Y, Y_row_stride,
     X, X_row_stride,
     W, W_row_stride,
-    r, r_row_stride,
-    n_cols, eps,
+    r, r_row_stride : tl.constexpr,
+    n_cols : tl.constexpr,
+    eps    : tl.constexpr,
     BLOCK_SIZE : tl.constexpr,
 ):
     # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
@@ -253,7 +256,6 @@ def unpatch_rms_layernorm():
     except:
         pass
     return
-    return
 pass


unsloth/models/_utils.py  (+30 −1)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "2025.3.10"
+__version__ = "2025.3.11"

 __all__ = [
     "SUPPORTS_BFLOAT16",
@@ -72,6 +72,7 @@
 platform_system = platform_system()
 import numpy as np
 import contextlib
+import re
 import warnings, subprocess, re, inspect, psutil, os, math
 from unsloth_zoo.utils import Version

@@ -181,6 +182,34 @@ def filter(self, x): return not (self.text in x.getMessage())
 except:
     pass

+# Patch get_model_param_count to record correct 4bit / 8bit
+from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled
+def get_model_param_count(model, trainable_only = False):
+    """
+    Calculate model's total param count. If trainable_only is True then count only those requiring grads
+    """
+    if is_deepspeed_zero3_enabled():
+        def numel(p):
+            return p.ds_numel if hasattr(p, "ds_numel") else p.numel()
+    else:
+        def numel(p):
+            return p.numel()
+    s = sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
+    if (not trainable_only) and \
+        hasattr(model, "config") and \
+        hasattr(model.config, "quantization_config"):
+
+        billions = re.findall(r"([0-9]{1,})(?:b|B)", model.config.name_or_path)
+        if len(billions) != 0:
+            billions = int(billions[0])
+            s = 1_000_000_000 * billions
+    pass
+    return s
+pass
+import transformers.trainer_pt_utils
+transformers.trainer_pt_utils.get_model_param_count = get_model_param_count
+import transformers.trainer
+transformers.trainer.get_model_param_count = get_model_param_count
 # =============================================

 # =============================================
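
The patched `get_model_param_count` falls back to parsing a parameter count out of the checkpoint name when a `quantization_config` is present, since `p.numel()` under-reports packed 4-bit weights. A quick illustration of that regex (the model names are examples only):

```python
import re

# "27b" -> approximately 27 billion parameters; only applied to quantized configs.
for name in ["unsloth/gemma-3-27b-pt-bnb-4bit", "unsloth/Meta-Llama-3.1-8B"]:
    billions = re.findall(r"([0-9]{1,})(?:b|B)", name)
    approx = 1_000_000_000 * int(billions[0]) if billions else None
    print(f"{name}: ~{approx:,} params" if approx else f"{name}: no match")
```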

unsloth/models/llama.py  (+20)

@@ -1663,6 +1663,10 @@ def from_pretrained(
         if platform.system().lower() == 'windows':
             print("Unsloth: vLLM does not work in Windows! Will use Unsloth inference!")
             fast_inference = False
+        major_version, minor_version = torch.cuda.get_device_capability()
+        if major_version < 7:
+            print("Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!")
+            fast_inference = False
     pass

     if token is None: token = get_token()
@@ -1786,6 +1790,8 @@
             attn_implementation = "eager",
             **kwargs,
         )
+        model.fast_generate = model.generate
+        model.fast_generate_batches = None
     else:
         from unsloth_zoo.vllm_utils import (
             load_vllm,
@@ -1804,6 +1810,7 @@
             enable_lora = True,
             max_lora_rank = max_lora_rank,
             disable_log_stats = disable_log_stats,
+            use_bitsandbytes = load_in_4bit,
         )
         for allowed_arg in allowed_args:
             if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
@@ -2651,6 +2658,19 @@ def patch_peft_model(
             torch.cuda.empty_cache()
         pass

+        # Patch for fast inference
+        vllm_engine = getattr(model.model, "vllm_engine", None)
+        if vllm_engine is not None:
+            model.vllm_engine = model.model.vllm_engine
+            model.fast_generate = model.model.fast_generate
+            model.fast_generate_batches = model.model.fast_generate_batches
+
+            # Also saving and loading LoRA
+            from unsloth_zoo.vllm_utils import save_lora, load_lora
+            model.save_lora = functools.partial(save_lora, model)
+            model.load_lora = functools.partial(load_lora, model)
+        pass
+
         # Add for_inference and for_training
         model.for_training  = functools.partial(FastLlamaModel.for_training,  model)
         model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
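
Taken together, these hunks route fast inference through either a vLLM engine or plain `generate` (older GPUs and Windows fall back, with `fast_generate` aliasing `model.generate`), and attach LoRA save/load helpers when a vLLM engine exists. A hedged usage sketch; the model name and LoRA hyperparameters are illustrative, not prescribed by this commit:

```python
from unsloth import FastLanguageModel
from vllm import SamplingParams

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/Meta-Llama-3.1-8B-Instruct",
    max_seq_length = 2048,
    load_in_4bit   = True,   # now forwarded to vLLM as use_bitsandbytes
    fast_inference = True,   # loads a vLLM engine when the GPU supports it
)
model = FastLanguageModel.get_peft_model(
    model, r = 16, lora_alpha = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
)

outputs = model.fast_generate(
    ["Why is the sky blue?"],
    sampling_params = SamplingParams(max_tokens = 64),
)
model.save_lora("my_lora")  # provided via unsloth_zoo.vllm_utils.save_lora
```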

unsloth/models/loader.py  (+15 −5)

@@ -405,7 +405,6 @@ def from_pretrained(
         if is_peft:
             # From https://github.com/huggingface/peft/issues/184
             # Now add PEFT adapters
-            model.enable_input_require_grads()
             model = PeftModel.from_pretrained(
                 model,
                 old_model_name,
@@ -498,10 +497,22 @@
             raise RuntimeError("Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST)
         elif "qwen2.5" in model_name.lower() and transformers_version < Version("4.49.0"):
             raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST)
-        elif "aya-vision" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
-            raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY)
+        elif "aya-vision" in model_name.lower():
+            # Disable compiling for now - errors out!
+            os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
+            if transformers_version < Version("4.50.0.dev0"):
+                raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY)
         elif "gemma-3" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
             raise RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY)
+        elif "c4ai-command-a-03-2025" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
+            raise RuntimeError("Unsloth: Cohere's Command model only works on transformers >= 4.50.0." + NIGHTLY)
+        elif "granite-vision" in model_name.lower():
+            # Disable compiling for now - errors out!
+            os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
+            if transformers_version < Version("4.50.0.dev0"):
+                raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY)
+        elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
+            raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY)
         pass

         if USE_MODELSCOPE and not os.path.exists(model_name):
@@ -668,7 +679,7 @@
             use_gradient_checkpointing = use_gradient_checkpointing,
             *args, **kwargs,
         )
-
+
         if resize_model_vocab is not None:
             model.resize_token_embeddings(resize_model_vocab)
         pass
@@ -703,7 +714,6 @@
         if is_peft:
             # From https://github.com/huggingface/peft/issues/184
             # Now add PEFT adapters
-            model.enable_input_require_grads()
             model = PeftModel.from_pretrained(
                 model,
                 old_model_name,
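
Note that the Aya Vision and Granite Vision branches above switch Unsloth's compiler off via an environment variable before the version check runs. The same switch can be set manually before loading other troublesome models; a minimal sketch (the model choice is illustrative):

```python
import os
# Must be set before the model is loaded so Unsloth skips its compiler patches,
# mirroring what loader.py now does for aya-vision and granite-vision.
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"

from unsloth import FastModel
model, tokenizer = FastModel.from_pretrained(
    model_name   = "unsloth/granite-vision-3.2-2b",
    load_in_4bit = True,
)
```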

unsloth/models/mapper.py  (+40)

@@ -62,6 +62,16 @@
         "unsloth/llama-2-7b-chat",
         "meta-llama/Llama-2-7b-chat-hf",
     ),
+    "unsloth/Mixtral-8x7B-v0.1-unsloth-bnb-4bit" : (
+        "unsloth/Mixtral-8x7B-v0.1",
+        "mistralai/Mixtral-8x7B-v0.1",
+        "unsloth/Mixtral-8x7B-v0.1-bnb-4bit",
+    ),
+    "unsloth/Mixtral-8x7B-Instruct-v0.1-unsloth-bnb-4bit" : (
+        "unsloth/Mixtral-8x7B-Instruct-v0.1",
+        "mistralai/Mixtral-8x7B-Instruct-v0.1",
+        "unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit",
+    ),
     "unsloth/codellama-7b-bnb-4bit" : (
         "unsloth/codellama-7b",
         "codellama/CodeLlama-7b-hf",
@@ -678,6 +688,36 @@
         "google/gemma-3-27b-pt",
         "unsloth/gemma-3-27b-pt-bnb-4bit",
     ),
+    "unsloth/reka-flash-3-unsloth-bnb-4bit" : (
+        "unsloth/reka-flash-3",
+        "RekaAI/reka-flash-3",
+        "unsloth/reka-flash-3-bnb-4bit",
+    ),
+    "unsloth/c4ai-command-a-03-2025-unsloth-bnb-4bit" : (
+        "unsloth/c4ai-command-a-03-2025",
+        "CohereForAI/c4ai-command-a-03-2025",
+        "unsloth/c4ai-command-a-03-2025-bnb-4bit",
+    ),
+    "unsloth/aya-vision-32b-unsloth-bnb-4bit" : (
+        "unsloth/aya-vision-32b",
+        "CohereForAI/aya-vision-32b",
+        "unsloth/aya-vision-32b-bnb-4bit",
+    ),
+    "unsloth/aya-vision-8b-unsloth-bnb-4bit" : (
+        "unsloth/aya-vision-8b",
+        "CohereForAI/aya-vision-8b",
+        "unsloth/aya-vision-8b-bnb-4bit",
+    ),
+    "unsloth/granite-vision-3.2-2b-unsloth-bnb-4bit" : (
+        "unsloth/granite-vision-3.2-2b",
+        "ibm-granite/granite-vision-3.2-2b",
+        "unsloth/granite-vision-3.2-2b-bnb-4bit",
+    ),
+    "unsloth/OLMo-2-0325-32B-Instruct-unsloth-bnb-4bit" : (
+        "unsloth/OLMo-2-0325-32B-Instruct",
+        "allenai/OLMo-2-0325-32B-Instruct",
+        "unsloth/OLMo-2-0325-32B-Instruct-bnb-4bit",
+    ),
 }

 INT_TO_FLOAT_MAPPER = {}
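
Each added entry keeps the existing mapper shape: the dynamically quantized `-unsloth-bnb-4bit` repo is keyed to what appears to be its 16-bit Unsloth mirror, the upstream repo, and the plain `-bnb-4bit` variant. A toy lookup mirroring one of the new entries (the helper is hypothetical, not the library's own resolution logic):

```python
# Mirrors the structure of the entries added above.
MODEL_SOURCES = {
    "unsloth/reka-flash-3-unsloth-bnb-4bit": (
        "unsloth/reka-flash-3",          # 16-bit Unsloth mirror (assumed ordering)
        "RekaAI/reka-flash-3",           # original upstream repo
        "unsloth/reka-flash-3-bnb-4bit", # plain bnb 4-bit variant
    ),
}

def full_precision_name(quantized_name: str) -> str:
    """Return the first (16-bit) source listed for a quantized repo name."""
    return MODEL_SOURCES[quantized_name][0]

print(full_precision_name("unsloth/reka-flash-3-unsloth-bnb-4bit"))
```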
