Skip to content

Commit b06d795

Browse files
committed
Fix formatting
1 parent 4c1b705 commit b06d795

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

src/petals/bloom/from_pretrained.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
from petals.bloom.block import WrappedBloomBlock
2222
from petals.server.block_utils import get_block_size
23-
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
2423
from petals.utils.convert_block import replace_8bit_linear
24+
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
2525

2626
logger = get_logger(__name__)
2727

src/petals/cli/convert_model.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@
1717

1818
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, int8=torch.int8, auto="auto")
1919

20+
2021
def main():
2122
parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
2223

2324
parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
2425
parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
25-
parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
26-
help="Load initial model in this dtype")
27-
parser.add_argument("--output_path", type=str, default="./converted_model",
28-
help="Track output repo to this folder")
26+
parser.add_argument(
27+
"--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", help="Load initial model in this dtype"
28+
)
29+
parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
2930
parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
3031
parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
3132
parser.add_argument(
@@ -54,10 +55,12 @@ def main():
5455
config.dht_prefix = args.output_repo
5556

5657
model = BloomModel.from_pretrained(
57-
args.model, use_auth_token=args.use_auth_token, revision=args.revision,
58+
args.model,
59+
use_auth_token=args.use_auth_token,
60+
revision=args.revision,
5861
torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
5962
load_in_8bit=args.torch_dtype == "int8",
60-
device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"}
63+
device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"},
6164
)
6265
if args.torch_dtype == "int8":
6366
# trigger weight quantization

0 commit comments

Comments (0)