Fix deepspeed with quantization #37324

Merged
2 commits merged on Apr 7, 2025
19 changes: 9 additions & 10 deletions src/transformers/modeling_utils.py
@@ -3719,19 +3719,14 @@ def float(self, *args):
 
     @classmethod
     def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
-        # With deepspeed, we cannot initialize the model on meta device
         if is_deepspeed_zero3_enabled():
             init_contexts = [no_init_weights()]
+            # We cannot initialize the model on meta device with deepspeed when not quantized
             if not is_quantized and not _is_ds_init_called:
                 logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-                init_contexts.extend(
-                    [
-                        deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
-                        set_zero3_state(),
-                    ]
-                )
+                init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
             elif is_quantized:
-                init_contexts.append(set_quantized_state())
+                init_contexts.extend([init_empty_weights(), set_quantized_state()])
         else:
             init_contexts = [no_init_weights(), init_empty_weights()]

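For context (not shown in this diff): the list returned by get_init_context is entered as a stack of context managers around model construction, so the branch above decides whether parameters are created under zero.init(), on the meta device, or plainly without weight init. A minimal sketch of that consumption pattern, using contextlib.ExitStack and placeholder names (MyModel, config) rather than the actual from_pretrained call site:

```python
from contextlib import ExitStack

# Sketch only: MyModel and config are stand-ins; the real caller is from_pretrained.
init_contexts = MyModel.get_init_context(is_quantized=False, _is_ds_init_called=False)

with ExitStack() as stack:
    for ctx in init_contexts:
        # e.g. no_init_weights(), deepspeed.zero.Init(...), set_zero3_state()
        stack.enter_context(ctx)
    model = MyModel(config)  # parameters are created while these contexts are active
```

With this change, the quantized ZeRO-3 branch also initializes weights on the meta device (init_empty_weights), matching the non-DeepSpeed path, instead of materializing full weights up front.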
@@ -4800,7 +4795,11 @@ def _load_pretrained_model(
                 continue
 
             map_location = "cpu"
-            if shard_file.endswith(".safetensors") and not is_hqq_or_bnb and not is_deepspeed_zero3_enabled():
+            if (
+                shard_file.endswith(".safetensors")
+                and not is_hqq_or_bnb
+                and not (is_deepspeed_zero3_enabled() and not is_quantized)
+            ):
                 map_location = "meta"
             elif (
                 device_map is not None
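The reason a safetensors shard can no longer be mapped to the meta device for un-quantized ZeRO-3 (but still can when quantized) comes down to meta tensors having no data to scatter into the partitioned parameters. A small illustration of that distinction, independent of the loader itself:

```python
import torch

# Assumed, self-contained example; not the library's loading code.
real = torch.empty(4, 4, device="cpu")   # has backing storage, values can be copied out
lazy = torch.empty(4, 4, device="meta")  # carries shape/dtype only, no storage

print(real.is_meta, lazy.is_meta)  # False True
# lazy.cpu() or lazy.tolist() would raise: a meta tensor cannot be copied out,
# so it cannot supply values to ZeRO-3's gather-and-copy loading path.
```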
@@ -4822,7 +4821,7 @@ def _load_pretrained_model(
             # Fix the key names
             state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
 
-            if is_deepspeed_zero3_enabled():
+            if is_deepspeed_zero3_enabled() and not is_quantized:
                 error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
             # Skip it with fsdp on ranks other than 0
             elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
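Background for the branch touched here: under ZeRO-3 each parameter is partitioned across ranks, so writing checkpoint values requires gathering the parameter first, which is what _load_state_dict_into_zero3_model handles. A rough sketch of that pattern using DeepSpeed's public API; copy_value_into_zero3_param is a hypothetical helper, not the actual implementation:

```python
import torch
import deepspeed

def copy_value_into_zero3_param(module: torch.nn.Module, name: str, value: torch.Tensor):
    """Hypothetical helper: gather a ZeRO-3 partitioned parameter on all ranks,
    let rank 0 write the checkpoint value, then re-partition on context exit."""
    param = dict(module.named_parameters())[name]
    with deepspeed.zero.GatheredParameters([param], modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            param.data.copy_(value)
```

Quantized weights are created and filled outside this partitioned path, which appears to be the point of the added `and not is_quantized`: they fall through to the regular loading branch below instead.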