diff --git a/pretraining.py b/pretraining.py
index 95544db..c6523d6 100644
--- a/pretraining.py
+++ b/pretraining.py
@@ -370,8 +370,10 @@ def main():
         else getattr(torch, model_args.torch_dtype)
     )
     world_size = int(os.environ.get("WORLD_SIZE", 1))
-    if world_size > 1:
+    ddp = world_size != 1
+    if ddp:
         model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
+
     config = config_class.from_pretrained(
         model_args.model_name_or_path,
         torch_dtype=torch_dtype,
@@ -426,6 +428,8 @@ def main():
             model.print_trainable_parameters()
         else:
             logger.info("Full parameters training")
+            if model_args.model_type in ['chatglm']:
+                model = model.half()
             print_trainable_parameters(model)
 
     # Preprocessing the datasets.
@@ -604,7 +608,7 @@ def group_texts(examples):
         else:
             model.config.use_cache = True
         model.enable_input_require_grads()
-        if torch.cuda.device_count() > 1:
+        if not ddp and torch.cuda.device_count() > 1:
             # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
             model.is_parallelizable = True
             model.model_parallel = True
diff --git a/supervised_finetuning.py b/supervised_finetuning.py
index e9f04f9..fe5d1ff 100644
--- a/supervised_finetuning.py
+++ b/supervised_finetuning.py
@@ -873,7 +873,8 @@ def preprocess_function(examples):
             model.print_trainable_parameters()
         else:
             logger.info("Fine-tuning method: Full parameters training")
-            model = model.float()
+            if model_args.model_type in ['chatglm']:
+                model = model.half()
             print_trainable_parameters(model)
 
         logger.debug(f"Model: {model}")
@@ -883,11 +884,7 @@ def preprocess_function(examples):
             model.config.use_cache = False
         else:
             model.config.use_cache = True
-
-        try:
-            model.enable_input_require_grads()
-        except:
-            logger.warning(f"Could not enable input require_grads on model, skipping.")
+        model.enable_input_require_grads()
         if not ddp and torch.cuda.device_count() > 1:
            # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
             model.is_parallelizable = True
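
For context, the pattern both files converge on can be read as the standalone sketch below. It assumes a Hugging Face Transformers model and a dataclass-style `model_args` object with `device_map` and `model_type` fields; the `configure_distributed_training` helper name is hypothetical and only illustrates the intent of the hunks above, not code taken from either script.

import os

import torch


def configure_distributed_training(model, model_args):
    # Hypothetical helper, not part of pretraining.py or supervised_finetuning.py.
    # Detect torchrun-style DDP from WORLD_SIZE and pin the model to this
    # process's GPU via device_map, mirroring the first hunk above.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        model_args.device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))}

    # ChatGLM stays in fp16 for full-parameter training instead of being
    # upcast to fp32.
    if model_args.model_type in ['chatglm']:
        model = model.half()

    # Fall back to naive model parallelism only when not running under DDP,
    # so the Trainer does not layer DataParallel on top of DDP.
    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True
    return model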