Skip to content

Commit 56a3bee

Browse files
committed
Remove load_in_8bit from convert_block
1 parent 9bee2b7 commit 56a3bee

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/petals/server/throughput.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414
from petals.bloom.block import WrappedBloomBlock
1515
from petals.server.block_utils import resolve_block_dtype
16-
from petals.utils.convert_block import convert_block
16+
from petals.utils.convert_block import convert_block, replace_8bit_linear
1717
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
1818

1919
logger = get_logger(__name__)
@@ -149,7 +149,9 @@ def measure_compute_rps(
149149
tensor_parallel_devices = (device,)
150150
with torch.inference_mode():
151151
block = WrappedBloomBlock(config).to(dtype)
152-
block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
152+
if load_in_8bit:
153+
block = replace_8bit_linear(block)
154+
block = convert_block(block, config, tensor_parallel_devices, device, freeze=True)
153155

154156
cache = None
155157
elapsed = 0

0 commit comments

Comments (0)