15
15
16
16
# Module-level logger named after this module (standard logging convention);
# get_logger is presumably the project's logging helper — defined elsewhere in the package.
logger = get_logger (__name__ )
17
17
18
# Maps the --torch_dtype CLI choice to the value passed to
# BloomModel.from_pretrained(torch_dtype=...).
# "auto" is forwarded as the literal string, which from_pretrained interprets
# as "use the dtype stored in the checkpoint"; "int8" is handled specially
# downstream (load_in_8bit quantization) rather than as a plain torch dtype.
DTYPE_MAP = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
    "int8": torch.int8,
    "auto": "auto",
}
20
19
21
20
def main ():
22
21
parser = argparse .ArgumentParser (description = "Load bloom layers and convert to 8-bit using torch quantization." )
23
22
24
23
parser .add_argument ("--model" , type = str , default = "bigscience/bloom-6b3" , help = "Model name for from_pretrained" )
25
24
parser .add_argument ("--revision" , type = str , default = None , help = "Optional commit id from HF hub" )
26
- parser .add_argument ("--torch_dtype" , type = str , default = "auto" , help = "Load initial model in this dtype" )
27
- parser .add_argument ("--output_path" , type = str , default = "./converted_model" , help = "Track output repo to this folder" )
25
+ parser .add_argument ("--torch_dtype" , type = str , choices = DTYPE_MAP .keys (), default = "auto" ,
26
+ help = "Load initial model in this dtype" )
27
+ parser .add_argument ("--output_path" , type = str , default = "./converted_model" ,
28
+ help = "Track output repo to this folder" )
28
29
parser .add_argument ("--output_repo" , type = str , default = "bigscience/test-bloomd" , help = "Push to this HF hub repo" )
29
30
parser .add_argument ("--client_branch" , type = str , default = CLIENT_BRANCH , help = "Save client version to this branch" )
30
31
parser .add_argument (
@@ -41,7 +42,6 @@ def main():
41
42
if args .model == "bigscience/bloom" and free_ram_gb < 400 :
42
43
logger .warning (f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have { free_ram_gb :.3f} free" )
43
44
44
- assert args .torch_dtype in DTYPE_MAP , f"torch_dtype must be one of { list (DTYPE_MAP .keys ())} "
45
45
if os .path .exists (args .output_path ) and (
46
46
len (os .listdir (args .output_path )) != 0 or not os .path .isdir (args .output_path )
47
47
):
@@ -54,8 +54,15 @@ def main():
54
54
config .dht_prefix = args .output_repo
55
55
56
56
model = BloomModel .from_pretrained (
57
- args .model , use_auth_token = args .use_auth_token , revision = args .revision , torch_dtype = DTYPE_MAP [args .torch_dtype ]
57
+ args .model , use_auth_token = args .use_auth_token , revision = args .revision ,
58
+ torch_dtype = DTYPE_MAP [args .torch_dtype ] if args .torch_dtype != "int8" else "float16" ,
59
+ load_in_8bit = args .torch_dtype == "int8" ,
60
+ device_map = {"word_embeddings" : "cuda" , "word_embeddings_layernorm" : "cuda" , "h" : "cuda" , "ln_f" : "cuda" }
58
61
)
62
+ if args .torch_dtype == "int8" :
63
+ # trigger weight quantization
64
+ model = model .cuda ()
65
+
59
66
if args .resize_token_embeddings :
60
67
logger .info (f"Resizing token embeddings, new size = { args .resize_token_embeddings } " )
61
68
model .resize_token_embeddings (args .resize_token_embeddings )
0 commit comments