-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -131,7 +131,8 @@ def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
         tokenizer.pad_token = tokenizer.eos_token
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    assert tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!"
+    assert (tokenizer.pad_token
+            is not None), f"Pad token for {model_type} cannot be set!"
 
     return tokenizer
 
@@ -158,9 +159,9 @@ def get_model(ckpt_path, dtype="fp16", device="cuda"):
 
     model_dtype = next(model.parameters()).dtype
    if dtype != model_dtype:
-        print(
-            f"[TensorRT-LLM][WARNING] The manually set model data type is {dtype}, "
-            f"but the data type of the HuggingFace model is {model_dtype}.")
+        print("[TensorRT-LLM][WARNING] The manually set model data type is "
+              f"{dtype}, but the data type of the HuggingFace model is "
+              f"{model_dtype}.")
 
     return model
 
@@ -244,15 +245,13 @@ def main(args):
     else:
         if "awq" in args.qformat:
             if args.calib_size > 32:
-                print(
-                    f"AWQ calibration could take longer with calib_size = {args.calib_size}, Using"
-                    " calib_size=32 instead")
+                print("AWQ calibration could take longer with calib_size = "
+                      f"{args.calib_size}, Using calib_size=32 instead")
                 args.calib_size = 32
-            print(
-                "\nAWQ calibration could take longer than other calibration methods. Please"
-                " increase the batch size to speed up the calibration process. Batch size can be"
-                " set by adding the argument --batch_size <batch_size> to the command line.\n"
-            )
+            print("\nAWQ calibration could take longer than other calibration "
+                  "methods. Please increase the batch size to speed up the "
+                  "calibration process. Batch size can be set by adding the "
+                  "argument --batch_size <batch_size> to the command line.\n")
 
         calib_dataloader = get_calib_dataloader(
             tokenizer=tokenizer,
@@ -287,9 +286,8 @@ def main(args):
 
     with torch.inference_mode():
         if model_type is None:
-            print(
-                f"Unknown model type {type(model).__name__}. Continue exporting..."
-            )
+            print(f"Unknown model type {type(model).__name__}. Continue "
+                  "exporting...")
             model_type = f"unknown:{type(model).__name__}"
 
         export_path = args.output_dir