Commit 4c1b705

Support saving and loading 8-bit block weights
1 parent fd9400b commit 4c1b705

5 files changed: +27 -20 lines changed

src/petals/__init__.py (+1, -1)

@@ -1,6 +1,6 @@
 from petals.client import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.1.2"
+__version__ = "1.1.3"
 
 _initialize_logs()

src/petals/bloom/from_pretrained.py (+8, -1)

@@ -21,11 +21,12 @@
 from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import get_block_size
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.convert_block import replace_8bit_linear
 
 logger = get_logger(__name__)
 
 CLIENT_BRANCH = "main"
-BLOCK_BRANCH_PREFIX = "block_"
+BLOCK_BRANCH_PREFIX = "int8_block"
 
 
 def load_pretrained_block(
@@ -36,6 +37,8 @@ def load_pretrained_block(
     use_auth_token: Optional[str] = None,
     cache_dir: Optional[str] = None,
     max_disk_space: Optional[int] = None,
+    load_in_8bit=False,
+    device: Optional[Union[str, torch.device]] = None,
 ) -> WrappedBloomBlock:
     """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
 
@@ -45,6 +48,10 @@ def load_pretrained_block(
        cache_dir = DEFAULT_CACHE_DIR
 
     block = WrappedBloomBlock(config)
+    if load_in_8bit:
+        block = replace_8bit_linear(block)
+    block = block.to(device)
+
     state_dict = _load_state_dict(
         converted_model_name_or_path,
         block_index,
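With the two new arguments, a caller can request an already-quantized block instead of quantizing it after loading. A minimal sketch of such a call; the repo name, block index, and the leading positional arguments are assumptions based on how load_pretrained_block is used elsewhere in Petals, not part of this diff:

# Hedged sketch: assumes petals and bitsandbytes are installed and a CUDA device is available.
import torch
from transformers import BloomConfig
from petals.bloom.from_pretrained import load_pretrained_block

config = BloomConfig.from_pretrained("bigscience/test-bloomd")  # assumed converted-model repo
block = load_pretrained_block(
    "bigscience/test-bloomd",  # converted_model_name_or_path (assumed)
    3,                         # block_index (assumed)
    config,
    load_in_8bit=True,         # new in this commit: wrap linear layers via replace_8bit_linear before loading weights
    device="cuda",             # new in this commit: the block is moved here before its state dict is loaded
)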

src/petals/cli/convert_model.py (+13, -6)

@@ -15,16 +15,17 @@
 
 logger = get_logger(__name__)
 
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
-
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, int8=torch.int8, auto="auto")
 
 def main():
     parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
 
     parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
     parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
-    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
-    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
+    parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
+                        help="Load initial model in this dtype")
+    parser.add_argument("--output_path", type=str, default="./converted_model",
+                        help="Track output repo to this folder")
     parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
     parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
     parser.add_argument(
@@ -41,7 +42,6 @@ def main():
     if args.model == "bigscience/bloom" and free_ram_gb < 400:
         logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
 
-    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
     if os.path.exists(args.output_path) and (
         len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
     ):
@@ -54,8 +54,15 @@ def main():
     config.dht_prefix = args.output_repo
 
     model = BloomModel.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
+        args.model, use_auth_token=args.use_auth_token, revision=args.revision,
+        torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
+        load_in_8bit=args.torch_dtype == "int8",
+        device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"}
     )
+    if args.torch_dtype == "int8":
+        # trigger weight quantization
+        model = model.cuda()
+
     if args.resize_token_embeddings:
         logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
         model.resize_token_embeddings(args.resize_token_embeddings)
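The conversion script now accepts int8 as a --torch_dtype choice: in that mode the model is loaded in float16 with load_in_8bit=True and quantized by moving it to CUDA, so the block weights pushed to the hub are already 8-bit and can be fetched by the int8_block* branches the loader now expects. A hypothetical invocation, assuming the module is runnable with python -m and that you have push access to the output repo (the repo name below is illustrative, not from this diff):

    python -m petals.cli.convert_model --model bigscience/bloom-6b3 --torch_dtype int8 --output_repo your-username/test-bloomd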

src/petals/server/server.py (+3, -1)

@@ -407,8 +407,10 @@ def create(
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                max_disk_space=max_disk_space,
+               load_in_8bit=load_in_8bit,
+               device=device,
            )
-           block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
+           block = convert_block(block, block_config, tensor_parallel_devices, device, freeze=True)
 
            backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
            blocks[module_uid] = TransformerBackend(

src/petals/utils/convert_block.py (+2, -11)

@@ -23,21 +23,17 @@ def convert_block(
     config: BloomConfig,
     tensor_parallel_devices: Sequence[torch.device],
     output_device: torch.device,
-    load_in_8bit: bool,
-    threshold: float = 6.0,
     freeze: bool = True,
 ) -> tp.TensorParallel:
     """
-    Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
+    Optimize a transformer block for use in a Petals server and apply tensor parallelism
 
     :note: some optimizations will modify the input block in-place!
     :param block: a single transformer block, either pre-trained or newly initialized
     :param config: HF transformers config for the full model
     :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
     :note: if there is only a single device, model wil still be wrapped with TensorParallel (for uniformity)
     :param output_device: if tensor_parallel_devices is True, output
-    :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint
-    :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
     :param freeze: if True (default), make all module parameters non-trainable
     :return: a module that acts like the original block, but runs with all specified optimizations
     """
@@ -48,9 +44,6 @@ def convert_block(
 
     block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
 
-    if load_in_8bit:
-        block = replace_8bit_linear(block, threshold=threshold)
-
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
 
@@ -77,15 +70,13 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0):
     # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
     import bitsandbytes as bnb
 
-    from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
-
     for n, module in model.named_children():
         if len(list(module.children())) > 0:
             replace_8bit_linear(module, threshold)
 
         if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
             assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
-            model._modules[n] = CustomLinear8bitLt(
+            model._modules[n] = bnb.nn.Linear8bitLt(
                 module.in_features,
                 module.out_features,
                 module.bias is not None,
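Since replace_8bit_linear now builds stock bnb.nn.Linear8bitLt layers instead of the patched CustomLinear8bitLt, its effect can be checked on any small module. A hedged sketch under the assumptions that bitsandbytes is installed, a CUDA device is available, and the function returns its input model as it does in the diff above (the toy module itself is hypothetical):

# Sketch only: swap nn.Linear layers of a toy CPU module for 8-bit layers, then quantize on GPU.
import torch.nn as nn
import bitsandbytes as bnb
from petals.utils.convert_block import replace_8bit_linear

toy = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))  # hypothetical toy module on CPU
toy = replace_8bit_linear(toy)                 # replaces each nn.Linear (except lm_head/score) in place
assert isinstance(toy[0], bnb.nn.Linear8bitLt)
toy = toy.to("cuda")                           # moving to GPU triggers bitsandbytes weight quantization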
