Skip to content

Commit 275f6e9

Browse files
authored
Disable AMP by default on CPU (#9218)
Co-authored-by: Haifeng Jin <[email protected]>
1 parent 3e556dc commit 275f6e9

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

benchmarks/torchbench_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,10 @@ def is_accelerator_tpu(self):
373373
return self.benchmark_experiment.accelerator == "tpu"
374374

375375
def use_amp(self):
    """Return True when this benchmark run should enable automatic mixed precision.

    AMP is only supported on cuda and tpu, not on cpu, so CPU runs always
    get False (with a warning). Otherwise AMP is used for training runs and
    for models the config explicitly forces into fp16/bf16.
    """
    if self.benchmark_experiment.accelerator == "cpu":
      logger.warning("AMP is not used due to running on CPU.")
      return False
    forced_amp_models = config().dtype.force_amp_for_fp16_bf16_models
    return self.is_training() or self.model_name in forced_amp_models
378382

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import unittest
2+
3+
from benchmarks.torchbench_model import TorchBenchModel
4+
5+
6+
class MockExperiment:
  """Minimal stand-in for a benchmark experiment used by the unit tests.

  Only carries the two attributes that TorchBenchModel.use_amp reads:
  the accelerator name and the test phase.
  """

  def __init__(self, accelerator, test):
    # Accelerator the mock pretends to run on, e.g. "cpu", "cuda", "tpu".
    self.accelerator = accelerator
    # BUG FIX: the original hard-coded self.test = "train", silently
    # discarding the caller-supplied `test` argument.
    self.test = test
11+
12+
13+
class TorchBenchModelTest(unittest.TestCase):
  """Exercises TorchBenchModel.use_amp across accelerator types."""

  def _build_model(self, accelerator):
    # Helper: construct a TorchBenchModel backed by a mock experiment
    # running a training phase on the given accelerator.
    exp = MockExperiment(accelerator, "train")
    return TorchBenchModel("torchbench or other", "super_deep_model", exp)

  def test_do_not_use_amp_on_cpu_and_warns(self):
    model = self._build_model("cpu")
    # Running on CPU must both disable AMP and emit exactly one warning.
    with self.assertLogs('benchmarks.torchbench_model', level='WARNING') as logs:
      amp_enabled = model.use_amp()
    self.assertEqual(len(logs.output), 1)
    self.assertIn("AMP is not used", logs.output[0])
    self.assertFalse(amp_enabled)

  def test_use_amp_on_cuda(self):
    # Training on CUDA should keep AMP enabled.
    model = self._build_model("cuda")
    self.assertTrue(model.use_amp())
30+
31+
32+
# Allow running this test module directly (e.g. `python <file>.py`)
# in addition to discovery via a test runner.
if __name__ == '__main__':
  unittest.main()

0 commit comments

Comments
 (0)