Skip to content

Commit 9bee2b7

Browse files
committed
Set device_map only for int8
1 parent b06d795 commit 9bee2b7

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/petals/cli/convert_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def main():
60 60           revision=args.revision,
61 61           torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
62 62           load_in_8bit=args.torch_dtype == "int8",
63    -         device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"},
   63 +         device_map="auto" if args.torch_dtype == "int8" else None,
64 64       )
65 65       if args.torch_dtype == "int8":
66 66           # trigger weight quantization

0 commit comments

Comments
 (0)