7 changes: 6 additions & 1 deletion src/lightning/pytorch/accelerators/accelerator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC
-from typing import Any
+from typing import Any, Optional
 
 import lightning.pytorch as pl
 from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator
@@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
 
         """
         raise NotImplementedError
+
+    @classmethod
+    def device_name(cls, device: Optional[_DEVICE] = None) -> str:
+        """Get the device name for a given device."""
+        return str(cls.is_available())
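
For orientation (not part of the diff): the sketch below shows how a custom accelerator might override the new hook. `MyAccelerator` and its return value are hypothetical; only the `device_name` signature comes from the change above, and the other abstract `Accelerator` methods are omitted.

# Illustrative sketch only -- `MyAccelerator` is hypothetical and never instantiated;
# it just shows the device_name contract: a human-readable name when the backend is
# usable, "" otherwise.
from typing import Optional

from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator


class MyAccelerator(Accelerator):
    @classmethod
    def device_name(cls, device: Optional[_DEVICE] = None) -> str:
        return "my-custom-device"


print(MyAccelerator.device_name())  # -> "my-custom-device"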
7 changes: 7 additions & 0 deletions src/lightning/pytorch/accelerators/cuda.py
@@ -113,6 +113,13 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             description=cls.__name__,
         )
 
+    @classmethod
+    @override
+    def device_name(cls, device: Optional[_DEVICE] = None) -> str:
+        if not cls.is_available():
+            return ""
+        return torch.cuda.get_device_name(device)
+
 
 def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]:  # pragma: no-cover
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
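
A usage sketch (assuming a machine where CUDA is available; the printed name is only an example):

# Illustrative usage of the new classmethod; no accelerator instance is needed.
from lightning.pytorch.accelerators.cuda import CUDAAccelerator

if CUDAAccelerator.is_available():
    print(CUDAAccelerator.device_name(0))        # e.g. "NVIDIA A100-SXM4-40GB"
else:
    print(repr(CUDAAccelerator.device_name()))   # '' when CUDA is unavailable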
8 changes: 8 additions & 0 deletions src/lightning/pytorch/accelerators/mps.py
@@ -87,6 +87,14 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             description=cls.__name__,
         )
 
+    @classmethod
+    @override
+    def device_name(cls, device: Optional[_DEVICE] = None) -> str:
+        # todo: implement a better way to get the device name
+        if not cls.is_available():
+            return ""
+        return "True (mps)"
+
 
 # device metrics
 _VM_PERCENT = "M1_vm_percent"
25 changes: 24 additions & 1 deletion src/lightning/pytorch/accelerators/xla.py
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Optional
 
 from typing_extensions import override
 
 from lightning.fabric.accelerators import _AcceleratorRegistry
+from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
 from lightning.fabric.accelerators.xla import XLAAccelerator as FabricXLAAccelerator
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -53,3 +54,25 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register("tpu", cls, description=cls.__name__)
+
+    @classmethod
+    @override
+    def device_name(cls, device: Optional[_DEVICE] = None) -> str:
+        is_available = cls.is_available()
+        if not is_available:
+            return ""
+
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla._internal import tpu
+        else:
+            from torch_xla.experimental import tpu
+        import torch_xla.core.xla_env_vars as xenv
+        from requests.exceptions import HTTPError
+
+        try:
+            ret = tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]
+        except HTTPError:
+            # Fallback to "True" if HTTPError is raised during retrieving device information
+            ret = str(is_available)
+
+        return ret
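
The try/except above degrades to the stringified availability flag rather than raising when TPU metadata cannot be fetched. A standalone sketch of that shape, with `lookup_accelerator_type` as a hypothetical stand-in for `tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]`:

# Standalone sketch of the fallback pattern; not torch_xla code.
from requests.exceptions import HTTPError


def describe_device(lookup_accelerator_type, is_available: bool = True) -> str:
    if not is_available:
        return ""
    try:
        return lookup_accelerator_type()
    except HTTPError:
        # Degrade gracefully to "True" rather than failing the info log.
        return str(is_available)


print(describe_device(lambda: "v4-8"))  # -> "v4-8"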
21 changes: 12 additions & 9 deletions src/lightning/pytorch/trainer/setup.py
@@ -152,20 +152,23 @@ def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str
 
 def _log_device_info(trainer: "pl.Trainer") -> None:
     if CUDAAccelerator.is_available():
-        gpu_available = True
-        gpu_type = " (cuda)"
+        if isinstance(trainer.accelerator, CUDAAccelerator):
+            device_name = ", ".join(list({CUDAAccelerator.device_name(d) for d in trainer.device_ids}))
+        else:
+            device_name = CUDAAccelerator.device_name()
     elif MPSAccelerator.is_available():
-        gpu_available = True
-        gpu_type = " (mps)"
+        device_name = MPSAccelerator.device_name()
     else:
-        gpu_available = False
-        gpu_type = ""
+        device_name = str(False)
 
-    gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator))
-    rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")
+    gpu_used = trainer.num_devices if isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator)) else 0
+    rank_zero_info(f"GPU available: {device_name}, using: {gpu_used} devices.")
 
     num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
-    rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")
+    rank_zero_info(
+        f"TPU available: {XLAAccelerator.device_name() if XLAAccelerator.is_available() else str(False)}, "
+        f"using: {num_tpu_cores} TPU cores"
+    )
 
     if _habana_available_and_importable():
         from lightning_habana import HPUAccelerator
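
To illustrate the de-duplication in the CUDA branch above: the set comprehension collapses repeated device names before joining, so homogeneous multi-GPU machines log the name once. The data below is hypothetical, standing in for `CUDAAccelerator.device_name(d)`:

# Hypothetical names keyed by device id, standing in for CUDAAccelerator.device_name(d).
names_by_id = {0: "NVIDIA A100-SXM4-40GB", 1: "NVIDIA A100-SXM4-40GB"}
device_ids = [0, 1]

device_name = ", ".join(list({names_by_id[d] for d in device_ids}))
print(device_name)  # -> "NVIDIA A100-SXM4-40GB" (duplicates collapse to one entry)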
13 changes: 13 additions & 0 deletions tests/tests_pytorch/accelerators/test_gpu.py
@@ -68,3 +68,16 @@ def test_gpu_availability():
 def test_warning_if_gpus_not_used(cuda_count_1):
     with pytest.warns(UserWarning, match="GPU available but not used"):
         Trainer(accelerator="cpu")
+
+
+@RunIf(min_cuda_gpus=1)
+def test_gpu_device_name():
+    for i in range(torch.cuda.device_count()):
+        assert torch.cuda.get_device_name(i) == CUDAAccelerator.device_name(i)
+
+    with torch.device("cuda:0"):
+        assert torch.cuda.get_device_name(0) == CUDAAccelerator.device_name()
+
+
+def test_gpu_device_name_no_gpu(cuda_count_0):
+    assert CUDAAccelerator.device_name() == ""
11 changes: 11 additions & 0 deletions tests/tests_pytorch/accelerators/test_mps.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from collections import namedtuple
+from unittest import mock
 
 import pytest
 import torch
@@ -39,6 +40,16 @@ def test_mps_availability():
     assert MPSAccelerator.is_available()
 
 
+@RunIf(mps=True)
+def test_mps_device_name():
+    assert MPSAccelerator.device_name() == "True (mps)"
+
+
+def test_mps_device_name_not_available():
+    with mock.patch("torch.backends.mps.is_available", return_value=False):
+        assert MPSAccelerator.device_name() == ""
+
+
 def test_warning_if_mps_not_used(mps_count_1):
     with pytest.warns(UserWarning, match="GPU available but not used"):
         Trainer(accelerator="cpu")
13 changes: 13 additions & 0 deletions tests/tests_pytorch/accelerators/test_xla.py
@@ -302,6 +302,19 @@ def test_warning_if_tpus_not_used(tpu_available):
         Trainer(accelerator="cpu")
 
 
+@RunIf(tpu=True)
+def test_tpu_device_name():
+    from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
+
+    if _XLA_GREATER_EQUAL_2_1:
+        from torch_xla._internal import tpu
+    else:
+        from torch_xla.experimental import tpu
+    import torch_xla.core.xla_env_vars as xenv
+
+    assert XLAAccelerator.device_name() == tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]
+
+
 @pytest.mark.parametrize(
     ("devices", "expected_device_ids"),
     [
6 changes: 6 additions & 0 deletions tests/tests_pytorch/conftest.py
@@ -182,6 +182,7 @@ def thread_police_duuu_daaa_duuu_daaa():
 def mock_cuda_count(monkeypatch, n: int) -> None:
     monkeypatch.setattr(lightning.fabric.accelerators.cuda, "num_cuda_devices", lambda: n)
     monkeypatch.setattr(lightning.pytorch.accelerators.cuda, "num_cuda_devices", lambda: n)
+    monkeypatch.setattr(torch.cuda, "get_device_name", lambda _: "Mocked CUDA Device")
 
 
 @pytest.fixture
@@ -244,6 +245,11 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
     monkeypatch.setitem(sys.modules, "torch_xla", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
    monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
+    monkeypatch.setattr(
+        lightning.pytorch.accelerators.xla.XLAAccelerator,
+        "device_name",
+        lambda *_: "Mocked TPU Device",
+    )
 
 
 @pytest.fixture
2 changes: 1 addition & 1 deletion tests/tests_pytorch/plugins/test_cluster_integration.py
@@ -66,7 +66,7 @@ def test_ranks_available_manual_strategy_selection(_, strategy_cls):
     """Test that the rank information is readily available after Trainer initialization."""
     num_nodes = 2
     for cluster, variables, expected in environment_combinations():
-        with mock.patch.dict(os.environ, variables):
+        with mock.patch.dict(os.environ, variables), mock.patch("torch.cuda.get_device_name", return_value="GPU"):
            strategy = strategy_cls(
                 parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)], cluster_environment=cluster
             )