We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b06d795 · commit 9bee2b7 (Copy full SHA for 9bee2b7)
src/petals/cli/convert_model.py
@@ -60,7 +60,7 @@ def main():
60
revision=args.revision,
61
torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
62
load_in_8bit=args.torch_dtype == "int8",
63
- device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"},
+ device_map="auto" if args.torch_dtype == "int8" else None,
64
)
65
if args.torch_dtype == "int8":
66
# trigger weight quantization
0 commit comments