update half.
shibing624 committed Jul 27, 2023
1 parent 6560984 commit 8404c07
Showing 2 changed files with 9 additions and 8 deletions.
pretraining.py (8 changes: 6 additions & 2 deletions)
@@ -370,8 +370,10 @@ def main():
         else getattr(torch, model_args.torch_dtype)
     )
     world_size = int(os.environ.get("WORLD_SIZE", 1))
-    if world_size > 1:
+    ddp = world_size != 1
+    if ddp:
         model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
+
     config = config_class.from_pretrained(
         model_args.model_name_or_path,
         torch_dtype=torch_dtype,
@@ -426,6 +428,8 @@ def main():
         model.print_trainable_parameters()
     else:
         logger.info("Full parameters training")
+        if model_args.model_type in ['chatglm']:
+            model = model.half()
         print_trainable_parameters(model)
 
     # Preprocessing the datasets.
@@ -604,7 +608,7 @@ def group_texts(examples):
     else:
         model.config.use_cache = True
     model.enable_input_require_grads()
-    if torch.cuda.device_count() > 1:
+    if not ddp and torch.cuda.device_count() > 1:
         # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
         model.is_parallelizable = True
         model.model_parallel = True
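For reference, the device-placement logic this commit converges on in pretraining.py can be exercised on its own. The sketch below is illustrative, not the repository's code: the variable names mirror the diff, but LOCAL_RANK is read with a default of 0 so the snippet also runs outside a torchrun launch.

import os

import torch

# WORLD_SIZE and LOCAL_RANK are set by distributed launchers such as torchrun;
# a world size greater than 1 means DDP is active and each rank should hold a
# full model replica on its own local GPU.
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1

device_map = "auto"
if ddp:
    # Pin the whole model ("" = root module) to this process's GPU.
    device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))}

# Single process with several visible GPUs: the scripts rely on model
# parallelism instead, and flagging it (model.is_parallelizable /
# model.model_parallel) keeps the Hugging Face Trainer from also wrapping
# the model in nn.DataParallel.
if not ddp and torch.cuda.device_count() > 1:
    print("multi-GPU without DDP: use model parallelism, skip DataParallel")

print(f"world_size={world_size}, ddp={ddp}, device_map={device_map}")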
supervised_finetuning.py (9 changes: 3 additions & 6 deletions)
@@ -873,7 +873,8 @@ def preprocess_function(examples):
         model.print_trainable_parameters()
     else:
         logger.info("Fine-tuning method: Full parameters training")
-        model = model.float()
+        if model_args.model_type in ['chatglm']:
+            model = model.half()
         print_trainable_parameters(model)
     logger.debug(f"Model: {model}")

@@ -883,11 +884,7 @@ def preprocess_function(examples):
         model.config.use_cache = False
     else:
         model.config.use_cache = True
-
-    try:
-        model.enable_input_require_grads()
-    except:
-        logger.warning(f"Could not enable input require_grads on model, skipping.")
+    model.enable_input_require_grads()
     if not ddp and torch.cuda.device_count() > 1:
         # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
         model.is_parallelizable = True
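The dtype change in supervised_finetuning.py swaps the unconditional model.float() for a cast that only applies to ChatGLM during full-parameter training. A toy sketch of that decision follows; the nn.Linear is a stand-in for the loaded transformers model, and model_type is hard-coded here purely for illustration.

from torch import nn

# Stand-in for the base model; the real scripts load a transformers model and
# take the type from model_args.model_type.
model = nn.Linear(16, 16)
model_type = "chatglm"  # assumed value for illustration

# New behaviour: only ChatGLM is cast to fp16 for full-parameter training.
# Previously the model was always forced back to fp32 with model.float().
if model_type in ["chatglm"]:
    model = model.half()

print(next(model.parameters()).dtype)  # torch.float16 when the cast applies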
