12  12  # See the License for the specific language governing permissions and
13  13  # limitations under the License.
14  14
15      - __version__ = "2025.3.9"
    15  + __version__ = "2025.3.10"
16  16
17  17  __all__ = [
18  18      "SUPPORTS_BFLOAT16",
25  25      "__version__",
26  26      "HAS_FLASH_ATTENTION",
27  27      "HAS_FLASH_ATTENTION_SOFTCAPPING",
    28  +     "USE_MODELSCOPE",
28  29      "platform_system",
29  30      "patch_tokenizer",
30  31      "get_statistics",
100  101  from unsloth_zoo.loss_utils import (
101  102      HAS_CUT_CROSS_ENTROPY,
102  103      fused_linear_cross_entropy,
     104  +     _unsloth_get_batch_samples,
103  105  )
104  106  from unsloth_zoo.vision_utils import (
105  107      process_vision_info,
108  110      get_transformers_model_type,
109  111      unsloth_compile_transformers as _unsloth_compile_transformers,
110  112  )
     113  + from unsloth_zoo.training_utils import (
     114  +     prepare_model_for_training,
     115  + )
111  116
112  117  # =============================================
113  118  # Disable some warnings which can get annoying
@@ -508,67 +513,16 @@ def prepare_model_for_kbit_training(
508  513      use_gradient_checkpointing: Optional = True,
509  514      use_reentrant: Optional[bool] = True,
510  515  ) -> Any:
511       -     """
512       -     Calculates where to place the gradient checkpoints given n_layers.
513       -     We also freeze all other layers's gradients
514       -
515       -     Args:
516       -         model: Any LlamaModel with layers.
517       -         use_gradient_checkpointing (`bool`, *optional*):
518       -             Default enabled. Provides memory savings by not saving all activations,
519       -             but only some.
520       -         use_reentrant (`bool`, *optional*):
521       -             https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
522       -             Optimal gradient checkpointing algorithm which will be the default in
523       -             future Pytorch versions.
524       -     """
525       -
526       -     # Freeze all parameters except LoRA
527       -     with torch.no_grad():
528       -         for name, param in model.named_parameters():
529       -             if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:
530       -                 param.requires_grad_(True)
531       -                 # Also must be in float32!
532       -                 if param.dtype != torch.float32:
533       -                     name = name.replace("base_model", "model", 1)
534       -                     layer_number = re.search(r"\.[\d]{1,}\.", name).group(0)
535       -                     name = name.replace(layer_number, f"[{layer_number[1:-1]}].")
536       -                     name = name.replace(".weight", "", 1)
537       -                     exec(f"{name}.to(torch.float32)")
538       -                 pass
539       -             else:
540       -                 param.requires_grad_(False)
541       -             pass
542       -         pass
543       -
544       -     # Gradient checkpointing!
545       -     if use_gradient_checkpointing == "unsloth":
546       -
547       -         # Saves VRAM!
548       -         original_model = model
549       -         while hasattr(original_model, "model"):
550       -             original_model._offloaded_gradient_checkpointing = True
551       -             original_model = original_model.model
552       -         pass
553       -         original_model._offloaded_gradient_checkpointing = True
554       -
555       -         model.gradient_checkpointing_enable()
556       -
557       -     elif use_gradient_checkpointing == True:
558       -         model.gradient_checkpointing_enable()
559       -     pass
560       -
561       -     # If use_reentrant = True which is the Pytorch default, we just make the input requires_grad.
562       -     if use_reentrant:
563       -         if hasattr(model, "enable_input_require_grads"):
564       -             model.enable_input_require_grads()
565       -         else:
566       -             def make_inputs_require_grad(module, input, output):
567       -                 output.requires_grad_(True)
568       -             model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
569       -     pass
570       -
571       -     return model
     516  +     return prepare_model_for_training(
     517  +         model = model,
     518  +         use_gradient_checkpointing = use_gradient_checkpointing,
     519  +         use_reentrant = use_reentrant,
     520  +         full_finetuning = False,
     521  +         train_layernorms = False,
     522  +         train_embedding = False,
     523  +         train_lm_head = False,
     524  +         float32_mixed_precision = True,
     525  +     )
572  526  pass
573  527
574  528  # =============================================
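For readers following this hunk: the inline freeze-and-checkpoint logic is delegated to `prepare_model_for_training` from unsloth_zoo.training_utils. A minimal sketch of what that call is expected to cover for the LoRA-only case follows; `sketch_prepare_for_lora_training` is an illustrative stand-in, not the real helper, and it assumes a transformers/PEFT-style module.

def sketch_prepare_for_lora_training(model, use_gradient_checkpointing=True, use_reentrant=True):
    # Illustrative only: train the LoRA adapter weights, freeze everything else
    # (mirrors full_finetuning=False and train_layernorms/embedding/lm_head=False above).
    for name, param in model.named_parameters():
        is_lora = ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name
        param.requires_grad_(is_lora)

    # Gradient checkpointing trades recomputation for activation memory.
    if use_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()

    # Reentrant checkpointing needs inputs that require grad; transformers models expose this helper.
    if use_reentrant and hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    return model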
@@ -999,44 +953,6 @@ def test_mask_creation():
 999   953  pass
1000   954
1001   955
1002        - def _unsloth_get_batch_samples(self, epoch_iterator, num_batches):
1003        -     batch_samples = []
1004        -     num_items_in_batch = None
1005        -
1006        -     # Check if model allows **kwargs
1007        -     model = self.model
1008        -     f = model.base_model.model.forward if hasattr(model, "base_model") else model.forward
1009        -     has_kwargs = tuple(inspect.signature(f).parameters.values())[-1].kind == inspect._VAR_KEYWORD
1010        -
1011        -     # Iterate to find all batches
1012        -     for _ in range(num_batches):
1013        -         try:
1014        -             batch_samples += [next(epoch_iterator)]
1015        -         except StopIteration:
1016        -             break
1017        -     pass
1018        -
1019        -     # Get num_items_in_batch
1020        -     if has_kwargs and len(batch_samples) > 0 and "labels" in batch_samples[0]:
1021        -         try:
1022        -             num_items_in_batch = sum(
1023        -                 [(x["labels"][..., 1:] != -100).sum() for x in batch_samples]
1024        -             )
1025        -
1026        -             if self.args.average_tokens_across_devices:
1027        -                 num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
1028        -
1029        -             if torch.is_tensor(num_items_in_batch):
1030        -                 num_items_in_batch = num_items_in_batch.item()
1031        -
1032        -         except Exception as exception:
1033        -             logger.warning_once(exception)
1034        -     pass
1035        -
1036        -     return batch_samples, num_items_in_batch
1037        - pass
1038        -
1039        -
1040   956  def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
1041   957      num_items_in_batch = None
1042   958
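The deleted `_unsloth_get_batch_samples` now ships in unsloth_zoo.loss_utils (see the import hunk above). Its core computation, counting how many label tokens actually contribute to the loss across the accumulated micro-batches, can be reproduced in isolation; the tensors below are invented purely for illustration.

import torch

# Two micro-batches of labels; -100 marks positions ignored by the loss.
batch_samples = [
    {"labels": torch.tensor([[1, 5, -100, 7]])},
    {"labels": torch.tensor([[-100, 2, 3, -100]])},
]

# Labels are shifted by one for next-token prediction, hence [..., 1:].
num_items_in_batch = sum(
    (x["labels"][..., 1:] != -100).sum() for x in batch_samples
)
print(int(num_items_in_batch))  # 2 + 2 = 4 tokens counted toward the loss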
@@ -1053,7 +969,12 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
1053   969      # Get gradient accumulation steps if possible
1054   970      if num_items_in_batch is None and \
1055   971          getattr(getattr(self, "args", self), "gradient_accumulation_steps", 1) != 1:
1056        -         name = (model.base_model.model if hasattr(model, "base_model") else model).__class__.__name__
       972  +
       973  +         inner_model = model
       974  +         if hasattr(inner_model, "base_model"): inner_model = inner_model.base_model
       975  +         if hasattr(inner_model, "model"):      inner_model = inner_model.model
       976  +         name = inner_model.__class__.__name__
       977  +
1057   978          logger.warning_once(
1058   979              f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n" \
1059   980              "Using gradient accumulation will be very slightly less accurate.\n" \
@@ -1271,3 +1192,10 @@ def __str__(self): return LOGITS_ERROR_STRING
1271  1192      try:    exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
1272  1193      except: continue
1273  1194  pass
      1195  +
      1196  + USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
      1197  + if USE_MODELSCOPE:
      1198  +     if importlib.util.find_spec("modelscope") is None:
      1199  +         raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
      1200  +     pass
      1201  + pass
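A minimal usage sketch for the new flag, assuming the environment variable is set before the package is imported (the check above runs once at module load):

import os

# Must be set before unsloth is imported so the import-time check sees it.
os.environ["UNSLOTH_USE_MODELSCOPE"] = "1"   # also run: pip install modelscope -U

import unsloth  # raises ImportError if modelscope is not installed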