Skip to content

Commit b7686cc

Browse files
Warn when kernel version is too low on Linux (#2077)
* Warn when kernel version is too low on Linux. See #1929. On Linux with kernel version < 5.5, issues with hanging processes have been reported. It is not clear how to fix the issue, so instead we warn the user that they may encounter problems. Notes: As logging requires an initialized PartialState, the actual check happens at the end of Accelerator.__init__. In a similar vein, the docstring of get_logger has been adjusted to first initialize the Accelerator, as the example does not work as currently shown. * Reviewer comment: small change to docstring
1 parent f322987 commit b7686cc

File tree

5 files changed

+63
-3
lines changed

5 files changed

+63
-3
lines changed

src/accelerate/accelerator.py

+3
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
ProjectConfiguration,
6464
RNGType,
6565
TorchDynamoPlugin,
66+
check_os_kernel,
6667
compare_versions,
6768
convert_model,
6869
convert_outputs_to_fp32,
@@ -470,6 +471,8 @@ def __init__(
470471
# Set a flag tensor for early stopping and other breakpoints
471472
self.flag_tensor = None
472473

474+
check_os_kernel()
475+
473476
@property
474477
def use_distributed(self):
475478
"""

src/accelerate/logging.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -85,19 +85,18 @@ def get_logger(name: str, log_level: str = None):
8585
8686
```python
8787
>>> from accelerate.logging import get_logger
88+
>>> from accelerate import Accelerator
8889
8990
>>> logger = get_logger(__name__)
9091
92+
>>> accelerator = Accelerator()
9193
>>> logger.info("My log", main_process_only=False)
9294
>>> logger.debug("My log", main_process_only=True)
9395
9496
>>> logger = get_logger(__name__, log_level="DEBUG")
9597
>>> logger.info("My log")
9698
>>> logger.debug("My second log")
9799
98-
>>> from accelerate import Accelerator
99-
100-
>>> accelerator = Accelerator()
101100
>>> array = ["a", "b", "c", "d"]
102101
>>> letter_at_rank = array[accelerator.process_index]
103102
>>> logger.info(letter_at_rank, in_order=True)

src/accelerate/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@
165165
from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
166166
from .memory import find_executable_batch_size, release_memory
167167
from .other import (
168+
check_os_kernel,
168169
clear_environment,
169170
convert_bytes,
170171
extract_model_from_parallel,

src/accelerate/utils/other.py

+25
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@
1313
# limitations under the License.
1414

1515
import os
16+
import platform
17+
import re
1618
import socket
1719
from contextlib import contextmanager
1820
from functools import partial
1921
from types import MethodType
2022

2123
import torch
24+
from packaging.version import Version
2225

2326
from ..commands.config.default import write_basic_config # noqa: F401
27+
from ..logging import get_logger
2428
from ..state import PartialState
2529
from .constants import FSDP_PYTORCH_VERSION
2630
from .dataclasses import DistributedType
@@ -29,6 +33,9 @@
2933
from .versions import is_torch_version
3034

3135

36+
logger = get_logger(__name__)
37+
38+
3239
if is_tpu_available(check_device=False):
3340
import torch_xla.core.xla_model as xm
3441

@@ -252,3 +259,21 @@ def convert_bytes(size):
252259
size /= 1024.0
253260

254261
return f"{round(size, 2)} PB"
262+
263+
264+
def check_os_kernel():
    """Warn if the Linux kernel version is below the recommended minimum.

    Kernels older than 5.5.0 have been reported to cause hanging processes (see issue #1929). Nothing
    happens on non-Linux systems, or when no ``major.minor.patch`` version can be found in the kernel
    release string.
    """
    info = platform.uname()
    if info.system != "Linux":
        return

    # Release strings look like "5.15.0-35-generic"; only the first numeric triple is relevant.
    # Searching (instead of unpacking the result of ``re.split``) means an unexpected release format is
    # skipped rather than raising ``ValueError``: this check is purely advisory and must not crash.
    match = re.search(r"\d+\.\d+\.\d+", info.release)
    if match is None:
        return

    version = match.group(0)
    min_version = "5.5.0"
    if Version(version) < Version(min_version):
        msg = (
            f"Detected kernel version {version}, which is below the recommended minimum of {min_version}; this can "
            "cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher."
        )
        logger.warning(msg, main_process_only=True)

tests/test_utils.py

+32
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
import os
1616
import pickle
1717
import unittest
18+
import warnings
1819
from collections import UserDict, namedtuple
20+
from unittest.mock import Mock, patch
1921

2022
import torch
2123

24+
from accelerate.state import PartialState
2225
from accelerate.test_utils.testing import require_cuda, require_torch_min_version
2326
from accelerate.test_utils.training import RegressionModel
2427
from accelerate.utils import (
28+
check_os_kernel,
2529
convert_outputs_to_fp32,
2630
extract_model_from_parallel,
2731
find_device,
@@ -36,6 +40,10 @@
3640

3741

3842
class UtilsTester(unittest.TestCase):
43+
def setUp(self):
    """Create the shared ``PartialState`` before each test, since logging requires an initialized state."""
    PartialState()
46+
3947
def test_send_to_device(self):
4048
tensor = torch.randn(5, 2)
4149
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -173,3 +181,27 @@ def test_find_device(self):
173181
self.assertEqual(find_device([1, "a", torch.tensor([1, 2, 3])]), torch.device("cpu"))
174182
self.assertEqual(find_device({"a": 1, "b": torch.tensor([1, 2, 3])}), torch.device("cpu"))
175183
self.assertIsNone(find_device([1, "a"]))
184+
185+
def test_check_os_kernel_no_warning_when_release_gt_min(self):
    """No warning is logged when the kernel release is above the 5.5.0 minimum."""
    # check_os_kernel reports through its module-level logger, not the ``warnings`` module, so the
    # logger itself is inspected; recording ``warnings`` would stay empty regardless of behavior.
    uname = Mock(release="5.15.0-35-generic", system="Linux")
    with patch("platform.uname", return_value=uname), patch("accelerate.utils.other.logger") as mock_logger:
        check_os_kernel()
    mock_logger.warning.assert_not_called()
191+
192+
def test_check_os_kernel_no_warning_when_not_linux(self):
    """No warning is logged on a non-Linux system, even if the release is below 5.5.0."""
    # check_os_kernel reports through its module-level logger, not the ``warnings`` module, so the
    # logger itself is inspected; recording ``warnings`` would stay empty regardless of behavior.
    uname = Mock(release="5.4.0-35-generic", system="Darwin")
    with patch("platform.uname", return_value=uname), patch("accelerate.utils.other.logger") as mock_logger:
        check_os_kernel()
    mock_logger.warning.assert_not_called()
198+
199+
def test_check_os_kernel_warning_when_release_lt_min(self):
    """Exactly one WARNING mentioning both the detected and the minimum version is logged below 5.5.0."""
    old_kernel = Mock(release="5.4.0-35-generic", system="Linux")
    with patch("platform.uname", return_value=old_kernel), self.assertLogs() as captured:
        check_os_kernel()

    self.assertEqual(len(captured.records), 1)
    record = captured.records[0]
    self.assertEqual(record.levelname, "WARNING")
    # The message must name the detected version and the recommended minimum.
    for expected in ("5.4.0", "5.5.0"):
        self.assertIn(expected, record.msg)

0 commit comments

Comments
 (0)