39 changes: 30 additions & 9 deletions model2vec/distill/utils.py
@@ -3,26 +3,47 @@
 from logging import getLogger
 
 import torch
+from packaging import version
 
 logger = getLogger(__name__)
 
 
 def select_optimal_device(device: str | None) -> str:
     """
-    Guess what your optimal device should be based on backend availability.
+    Get the optimal device to use based on backend availability.
 
-    If you pass a device, we just pass it through.
+    For Torch versions >= 2.8.0, MPS is disabled due to known performance regressions.
 
-    :param device: The device to use. If this is not None you get back what you passed.
+    :param device: The device to use. If this is None, the device is automatically selected.
     :return: The selected device.
+    :raises RuntimeError: If MPS is requested on a PyTorch version where it is disabled.
     """
-    if device is None:
-        if torch.cuda.is_available():
-            device = "cuda"
-        elif torch.backends.mps.is_available():
-            device = "mps"
-        else:
-            device = "cpu"
-        logger.info(f"Automatically selected device: {device}")
+    # Get the torch version and check if MPS is broken
+    torch_version = version.parse(torch.__version__.split("+")[0])
+    mps_broken = torch_version >= version.parse("2.8.0")
+
+    if device:
+        if device == "mps" and mps_broken:
+            raise RuntimeError(
+                f"MPS is disabled for PyTorch {torch.__version__} due to known performance regressions. "
+                "Please use CPU or CUDA instead, or use a PyTorch version < 2.8.0."
+            )
+        else:
+            return device
+
+    if torch.cuda.is_available():
+        device = "cuda"
+    elif torch.backends.mps.is_available():
+        if mps_broken:
+            logger.warning(
+                f"MPS is available but PyTorch {torch.__version__} has known performance regressions. "
+                "Falling back to CPU. Please use a PyTorch version < 2.8.0 to enable MPS support."
+            )
+            device = "cpu"
+        else:
+            device = "mps"
+    else:
+        device = "cpu"
+
+    logger.info(f"Automatically selected device: {device}")
     return device
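
For context, here is the version-gating pattern the new code relies on, as a minimal standalone sketch (assuming only that torch and packaging are installed; the version strings in the comments are illustrative):

import torch
from packaging import version

# torch.__version__ can carry a local build suffix such as "2.8.0+cu128";
# splitting on "+" keeps only the public release segment, so the comparison
# is not affected by the local part.
torch_version = version.parse(torch.__version__.split("+")[0])
mps_broken = torch_version >= version.parse("2.8.0")

print(version.parse("2.8.0+cu128".split("+")[0]))        # 2.8.0
print(version.parse("2.7.1") >= version.parse("2.8.0"))  # False

With this gate in place, on a machine where only MPS is available and torch >= 2.8.0 is installed, select_optimal_device(None) logs a warning and returns "cpu", while an explicit select_optimal_device("mps") raises RuntimeError.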
44 changes: 24 additions & 20 deletions tests/test_utils.py
@@ -1,16 +1,10 @@
 from __future__ import annotations
 
-import json
-from pathlib import Path
-from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import Any
+from tempfile import NamedTemporaryFile
 from unittest.mock import patch
 
-import numpy as np
 import pytest
-import safetensors
-import safetensors.numpy
-from tokenizers import Tokenizer
 
 from model2vec.distill.utils import select_optimal_device
 from model2vec.hf_utils import _get_metadata_from_readme
@@ -39,26 +33,36 @@ def test__get_metadata_from_readme_mocked_file_keys() -> None:


 @pytest.mark.parametrize(
-    "device, expected, cuda, mps",
+    "torch_version, device, expected, cuda, mps, should_raise",
     [
-        ("cpu", "cpu", True, True),
-        ("cpu", "cpu", True, False),
-        ("cpu", "cpu", False, True),
-        ("cpu", "cpu", False, False),
-        ("clown", "clown", False, False),
-        (None, "cuda", True, True),
-        (None, "cuda", True, False),
-        (None, "mps", False, True),
-        (None, "cpu", False, False),
+        ("2.7.0", "cpu", "cpu", True, True, False),
+        ("2.8.0", "cpu", "cpu", True, True, False),
+        ("2.7.0", "clown", "clown", False, False, False),
+        ("2.8.0", "clown", "clown", False, False, False),
+        ("2.7.0", "mps", "mps", False, True, False),
+        ("2.8.0", "mps", None, False, True, True),
+        ("2.7.0", None, "cuda", True, True, False),
+        ("2.7.0", None, "mps", False, True, False),
+        ("2.7.0", None, "cpu", False, False, False),
+        ("2.8.0", None, "cuda", True, True, False),
+        ("2.8.0", None, "cpu", False, True, False),
+        ("2.8.0", None, "cpu", False, False, False),
+        ("2.9.0", None, "cpu", False, True, False),
+        ("3.0.0", None, "cpu", False, True, False),
     ],
 )
-def test_select_optimal_device(device: str | None, expected: str, cuda: bool, mps: bool) -> None:
-    """Test whether the optimal device is selected."""
+def test_select_optimal_device(torch_version, device, expected, cuda, mps, should_raise) -> None:
+    """Test whether the optimal device is selected across versions and backends."""
     with (
         patch("torch.cuda.is_available", return_value=cuda),
         patch("torch.backends.mps.is_available", return_value=mps),
+        patch("torch.__version__", torch_version),
     ):
-        assert select_optimal_device(device) == expected
+        if should_raise:
+            with pytest.raises(RuntimeError):
+                select_optimal_device(device)
+        else:
+            assert select_optimal_device(device) == expected
 
 
 def test_importable() -> None:
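
The test relies on the fact that unittest.mock.patch can replace a plain module attribute such as torch.__version__ just as easily as a callable, so every branch of select_optimal_device can be driven without real hardware or multiple torch installs. A minimal standalone sketch of the same technique (hypothetical test name; it assumes torch and model2vec are importable and mirrors the cases exercised above):

from unittest.mock import patch

import pytest

from model2vec.distill.utils import select_optimal_device


def test_mps_fallback_on_torch_2_8() -> None:
    # Simulate an Apple-silicon host (no CUDA, MPS present) running torch 2.8.0.
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=True),
        patch("torch.__version__", "2.8.0"),
    ):
        # Auto-selection falls back to CPU instead of the regressed MPS backend.
        assert select_optimal_device(None) == "cpu"
        # An explicit request for MPS is rejected outright.
        with pytest.raises(RuntimeError):
            select_optimal_device("mps")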