diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index 4cab69e63918..fe657924be55 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -356,6 +356,17 @@ } } +DISABLED_TORCH_TESTS_TPUVM_ONLY = { + # test_nn.py + 'TestNNDeviceTypeXLA': { + 'test_AdaptiveMaxPool1d_indices_xla', # TODO: segfualt on TPUVM + 'test_AdaptiveMaxPool2d_indices_xla', # TODO: segfualt on TPUVM + 'test_AdaptiveMaxPool3d_indices_xla', # TODO: segfualt on TPUVM + 'test_MaxPool3d_indices_xla', # TODO: segfualt on TPUVM + 'test_multi_margin_loss_errors_xla', # TODO: segfualt on TPUVM + }, +} + DISABLED_TORCH_TESTS_GPU_ONLY = { # test_torch.py 'TestTorchDeviceTypeXLA': { @@ -406,10 +417,18 @@ def union_of_disabled_tests(sets): return union +def on_tpuvm(): + config = os.getenv('XRT_TPU_CONFIG') + return config and re.match('^localservice;[0-9]+;localhost:[0-9]+', config) + + DISABLED_TORCH_TESTS_CPU = DISABLED_TORCH_TESTS_ANY DISABLED_TORCH_TESTS_GPU = union_of_disabled_tests( [DISABLED_TORCH_TESTS_ANY, DISABLED_TORCH_TESTS_GPU_ONLY]) -DISABLED_TORCH_TESTS_TPU = union_of_disabled_tests( +DISABLED_TORCH_TESTS_TPU = union_of_disabled_tests([ + DISABLED_TORCH_TESTS_ANY, DISABLED_TORCH_TESTS_TPU_ONLY, + DISABLED_TORCH_TESTS_TPUVM_ONLY +]) if on_tpuvm() else union_of_disabled_tests( [DISABLED_TORCH_TESTS_ANY, DISABLED_TORCH_TESTS_TPU_ONLY]) DISABLED_TORCH_TESTS = {