@@ -383,10 +383,13 @@ def from_pretrained(
383
383
patch_loss_functions ,
384
384
post_patch_loss_function ,
385
385
)
386
- from .vision import FastBaseVisionModel
387
-
386
+ from .vision import FastBaseModel
387
+ from transformers import (
388
+ AutoModelForVision2Seq ,
389
+ AutoModelForCausalLM ,
390
+ )
388
391
389
- class FastVisionModel ( FastBaseVisionModel ):
392
+ class FastModel ( FastBaseModel ):
390
393
@staticmethod
391
394
def from_pretrained (
392
395
model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit" ,
@@ -413,7 +416,7 @@ def from_pretrained(
413
416
patch_compiling_bitsandbytes ()
414
417
if use_gradient_checkpointing == "unsloth" :
415
418
patch_unsloth_smart_gradient_checkpointing (dtype = dtype )
416
-
419
+
417
420
old_model_name = model_name
418
421
if not use_exact_model_name :
419
422
model_name = get_model_name (model_name , load_in_4bit )
@@ -427,7 +430,7 @@ def from_pretrained(
427
430
from huggingface_hub .utils import disable_progress_bars , enable_progress_bars , are_progress_bars_disabled
428
431
was_disabled = are_progress_bars_disabled ()
429
432
disable_progress_bars ()
430
-
433
+
431
434
autoconfig_error = None
432
435
peft_error = None
433
436
try :
@@ -458,7 +461,7 @@ def from_pretrained(
458
461
459
462
# Old transformers versions check
460
463
both_exist = (is_model and is_peft ) and not SUPPORTS_LLAMA32
461
-
464
+
462
465
# New transformers need to check manually.
463
466
if SUPPORTS_LLAMA32 :
464
467
# Check if folder exists locally
@@ -515,9 +518,12 @@ def from_pretrained(
515
518
if not was_disabled : enable_progress_bars ()
516
519
517
520
do_logging = os .environ .get ("UNSLOTH_ENABLE_LOGGING" , "0" ) == "1"
518
- redirector = sys .stdout if do_logging else open (os .devnull , "w" )
521
+ if do_logging :
522
+ redirector = contextlib .nullcontext ()
523
+ else :
524
+ redirector = contextlib .redirect_stdout (open (os .devnull , "w" ))
519
525
520
- with contextlib . redirect_stdout ( redirector ) :
526
+ with redirector :
521
527
patch_loss_functions (torch_compile = False )
522
528
model_types = unsloth_compile_transformers (
523
529
model_name = model_name ,
@@ -547,7 +553,6 @@ def from_pretrained(
547
553
return_logits = return_logits ,
548
554
)
549
555
pass
550
- if do_logging : redirector .close ()
551
556
552
557
# Check if this is local model since the tokenizer gets overwritten
553
558
if os .path .exists (os .path .join (old_model_name , "tokenizer_config.json" )) and \
@@ -559,7 +564,12 @@ def from_pretrained(
559
564
tokenizer_name = None
560
565
pass
561
566
562
- model , tokenizer = FastBaseVisionModel .from_pretrained (
567
+ # Check if VLM
568
+ is_vlm = (x .endswith ("ForConditionalGeneration" ) for x in model_config .architectures )
569
+ is_vlm = is_vlm or hasattr (model_config , "vision_config" )
570
+ auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
571
+
572
+ model , tokenizer = FastBaseModel .from_pretrained (
563
573
model_name = model_name ,
564
574
max_seq_length = max_seq_length ,
565
575
dtype = _get_dtype (dtype ),
@@ -570,6 +580,7 @@ def from_pretrained(
570
580
revision = revision if not is_peft else None ,
571
581
model_types = model_types ,
572
582
tokenizer_name = tokenizer_name ,
583
+ auto_model = auto_model ,
573
584
* args , ** kwargs ,
574
585
)
575
586
@@ -617,8 +628,14 @@ def from_pretrained(
617
628
trust_remote_code = trust_remote_code ,
618
629
)
619
630
# Patch it as well!
620
- model = FastBaseVisionModel .patch_peft_model (model , use_gradient_checkpointing )
631
+ model = FastBaseModel .patch_peft_model (model , use_gradient_checkpointing )
621
632
pass
622
633
return model , tokenizer
623
634
pass
624
635
pass
636
+
637
+ class FastVisionModel (FastModel ):
638
+ pass
639
+
640
+ class FastTextModel (FastModel ):
641
+ pass
0 commit comments