Skip to content

Commit

Permalink
use the correct available memory API for XPU (#3076)
Browse files Browse the repository at this point in the history
* fix

* update

* remove blank line

* update

* add check

* add imports

* warning for both

* reformat
  • Loading branch information
faaany authored Sep 9, 2024
1 parent e7e0181 commit 4b4c036
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
18 changes: 17 additions & 1 deletion src/accelerate/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@
A collection of utilities for ensuring that training can always occur. Heavily influenced by the
[toma](https://github.com/BlackHC/toma) library.
"""

import functools
import gc
import importlib
import inspect
import warnings

import torch

from .imports import (
is_cuda_available,
is_ipex_available,
is_mlu_available,
is_mps_available,
is_musa_available,
is_npu_available,
is_xpu_available,
)
from .versions import compare_versions


def clear_device_cache(garbage_collection=False):
Expand Down Expand Up @@ -159,3 +162,16 @@ def decorator(*args, **kwargs):
raise

return decorator


def get_xpu_available_memory(device_index: int):
    """Return the available memory (in bytes) of the XPU device at `device_index`.

    When IPEX >= 2.5 is installed, uses the accurate `mem_get_info` API, which
    reports free device memory. Otherwise falls back to
    `torch.xpu.max_memory_allocated`, which returns peak *allocated* memory —
    not free memory — and emits a warning about the inaccuracy.

    Args:
        device_index (`int`): index of the XPU device to query.

    Returns:
        `int`: free memory in bytes (IPEX >= 2.5) or peak allocated bytes
        (older/missing IPEX, known-incorrect fallback).
    """
    if is_ipex_available():
        # `import importlib` alone does not load the `importlib.metadata`
        # submodule; import the symbol explicitly to avoid an AttributeError.
        from importlib.metadata import version as _pkg_version

        ipex_version = _pkg_version("intel_extension_for_pytorch")
        if compare_versions(ipex_version, ">=", "2.5"):
            from intel_extension_for_pytorch.xpu import mem_get_info

            # mem_get_info returns (free, total); index 0 is free bytes.
            return mem_get_info(device_index)[0]
    warnings.warn(
        "The XPU `mem_get_info` API is available in IPEX version >=2.5. The current returned available memory is incorrect. Please consider upgrading your IPEX version."
    )
    return torch.xpu.max_memory_allocated(device_index)
4 changes: 2 additions & 2 deletions src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
is_torch_xla_available,
is_xpu_available,
)
from .memory import clear_device_cache
from .memory import clear_device_cache, get_xpu_available_memory
from .offload import load_offloaded_weight, offload_weight, save_offload_index
from .tqdm import is_tqdm_available, tqdm
from .versions import compare_versions, is_torch_version
Expand Down Expand Up @@ -915,7 +915,7 @@ def get_max_memory(max_memory: Optional[Dict[Union[int, str], Union[int, str]]]
for i in range(torch.xpu.device_count()):
try:
_ = torch.tensor(0, device=torch.device("xpu", i))
max_memory[i] = torch.xpu.max_memory_allocated(i)
max_memory[i] = get_xpu_available_memory(i)
except Exception:
logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
continue
Expand Down

0 comments on commit 4b4c036

Please sign in to comment.