|
15 | 15 | import os
|
16 | 16 | import pickle
|
17 | 17 | import unittest
|
| 18 | +import warnings |
18 | 19 | from collections import UserDict, namedtuple
|
| 20 | +from unittest.mock import Mock, patch |
19 | 21 |
|
20 | 22 | import torch
|
21 | 23 |
|
| 24 | +from accelerate.state import PartialState |
22 | 25 | from accelerate.test_utils.testing import require_cuda, require_torch_min_version
|
23 | 26 | from accelerate.test_utils.training import RegressionModel
|
24 | 27 | from accelerate.utils import (
|
| 28 | + check_os_kernel, |
25 | 29 | convert_outputs_to_fp32,
|
26 | 30 | extract_model_from_parallel,
|
27 | 31 | find_device,
|
|
36 | 40 |
|
37 | 41 |
|
38 | 42 | class UtilsTester(unittest.TestCase):
|
| 43 | + def setUp(self): |
| 44 | + # logging requires initialized state |
| 45 | + PartialState() |
| 46 | + |
39 | 47 | def test_send_to_device(self):
|
40 | 48 | tensor = torch.randn(5, 2)
|
41 | 49 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
@@ -173,3 +181,27 @@ def test_find_device(self):
|
173 | 181 | self.assertEqual(find_device([1, "a", torch.tensor([1, 2, 3])]), torch.device("cpu"))
|
174 | 182 | self.assertEqual(find_device({"a": 1, "b": torch.tensor([1, 2, 3])}), torch.device("cpu"))
|
175 | 183 | self.assertIsNone(find_device([1, "a"]))
|
| 184 | + |
| 185 | + def test_check_os_kernel_no_warning_when_release_gt_min(self): |
| 186 | + # min version is 5.5 |
| 187 | + with patch("platform.uname", return_value=Mock(release="5.15.0-35-generic", system="Linux")): |
| 188 | + with warnings.catch_warnings(record=True) as w: |
| 189 | + check_os_kernel() |
| 190 | + self.assertEqual(len(w), 0) |
| 191 | + |
| 192 | + def test_check_os_kernel_no_warning_when_not_linux(self): |
| 193 | + # system must be Linux |
| 194 | + with patch("platform.uname", return_value=Mock(release="5.4.0-35-generic", system="Darwin")): |
| 195 | + with warnings.catch_warnings(record=True) as w: |
| 196 | + check_os_kernel() |
| 197 | + self.assertEqual(len(w), 0) |
| 198 | + |
| 199 | + def test_check_os_kernel_warning_when_release_lt_min(self): |
| 200 | + # min version is 5.5 |
| 201 | + with patch("platform.uname", return_value=Mock(release="5.4.0-35-generic", system="Linux")): |
| 202 | + with self.assertLogs() as ctx: |
| 203 | + check_os_kernel() |
| 204 | + self.assertEqual(len(ctx.records), 1) |
| 205 | + self.assertEqual(ctx.records[0].levelname, "WARNING") |
| 206 | + self.assertIn("5.4.0", ctx.records[0].msg) |
| 207 | + self.assertIn("5.5.0", ctx.records[0].msg) |
0 commit comments