Skip to content

Commit 79781a2

Browse files
authored
Update py_predictor_v2.py
1 parent 852678d commit 79781a2

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

metapredict/backend/py_predictor_v2.py

+4
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ def __init__(self, saved_weights, dtype, gpuid='cpu'):
7979
if torch.cuda.is_available():
8080
device_string = f"cuda:{gpuid}"
8181
device = torch.device(device_string)
82+
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
83+
# Use MPS if available on ARM-based MacBooks
84+
device_string = "mps"
85+
device = torch.device(device_string)
8286
else:
8387
device_string = "cpu"
8488
device = torch.device(device_string)

0 commit comments

Comments
 (0)