Skip to content

Commit fd9400b

Browse files
authored
Fix use_chunked_forward="auto" on non-x86_64 machines (#267)
Import of cpufeature may crash on non-x86_64 machines, so this PR makes the client import it only if necessary.
1 parent a2e7f27 commit fd9400b

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/petals/bloom/modeling_utils.py

+11-4
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,12 @@
44
See commit history for authorship.
55
"""
66

7+
import platform
8+
79
import psutil
810
import torch
911
import torch.nn.functional as F
1012
import torch.utils.checkpoint
11-
from cpufeature import CPUFeature
1213
from hivemind import get_logger
1314
from torch import nn
1415
from transformers import BloomConfig
@@ -29,9 +30,15 @@ def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
2930

3031
self.use_chunked_forward = config.use_chunked_forward
3132
if self.use_chunked_forward == "auto":
32-
# If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
33-
# Otherwise, it's ~8x slower.
34-
self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
33+
if platform.machine() == "x86_64":
34+
# Import of cpufeature may crash on non-x86_64 machines
35+
from cpufeature import CPUFeature
36+
37+
# If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
38+
# Otherwise, it's ~8x slower.
39+
self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
40+
else:
41+
self.use_chunked_forward = True
3542
self.chunked_forward_step = config.chunked_forward_step
3643
self._bf16_warning_shown = False
3744

0 commit comments

Comments
 (0)