|
17 | 17 |
|
18 | 18 | DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, int8=torch.int8, auto="auto")
|
19 | 19 |
|
| 20 | + |
20 | 21 | def main():
|
21 | 22 | parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
|
22 | 23 |
|
23 | 24 | parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
|
24 | 25 | parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
|
25 |
| - parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", |
26 |
| - help="Load initial model in this dtype") |
27 |
| - parser.add_argument("--output_path", type=str, default="./converted_model", |
28 |
| - help="Track output repo to this folder") |
| 26 | + parser.add_argument( |
| 27 | + "--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", help="Load initial model in this dtype" |
| 28 | + ) |
| 29 | + parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder") |
29 | 30 | parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
|
30 | 31 | parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
|
31 | 32 | parser.add_argument(
|
@@ -54,10 +55,12 @@ def main():
|
54 | 55 | config.dht_prefix = args.output_repo
|
55 | 56 |
|
56 | 57 | model = BloomModel.from_pretrained(
|
57 |
| - args.model, use_auth_token=args.use_auth_token, revision=args.revision, |
| 58 | + args.model, |
| 59 | + use_auth_token=args.use_auth_token, |
| 60 | + revision=args.revision, |
58 | 61 | torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
|
59 | 62 | load_in_8bit=args.torch_dtype == "int8",
|
60 |
| - device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"} |
| 63 | + device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"}, |
61 | 64 | )
|
62 | 65 | if args.torch_dtype == "int8":
|
63 | 66 | # trigger weight quantization
|
|
0 commit comments