
Commit 6398eef

Fix pytest-xdist compatibility in test_torch_ops (#2542)
This PR replaces the direct torch operator references in `pytest.mark.parametrize` with operator-name strings that are resolved via `getattr` inside each test in `test_torch_ops`, so the tests can run in parallel via pytest-xdist.
1 parent 3d6fd01 commit 6398eef
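The likely reason this helps: when function objects are passed to `pytest.mark.parametrize`, pytest may derive the test IDs from their reprs, which can differ between processes and make pytest-xdist workers collect mismatched test sets; plain string parameters give stable IDs. A minimal, hypothetical sketch of the pattern (not the actual test code from this PR):

```python
import pytest
import torch

# Parametrize over operator *names* rather than torch function objects so the
# generated test IDs are plain strings that collect identically on every
# pytest-xdist worker.
@pytest.mark.parametrize("op_name", ["var", "std"])
@pytest.mark.parametrize("unbiased", [True, False])
def test_op_by_name(op_name, unbiased):
    torch_op = getattr(torch, op_name)  # resolve e.g. "var" -> torch.var
    x = torch.randn(1, 5, 10)
    expected = getattr(x, op_name)(unbiased=unbiased)  # equivalent Tensor-method call
    assert torch.allclose(torch_op(x, unbiased=unbiased), expected)
```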

File tree

1 file changed: +33 -27 lines changed


coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 33 additions & 27 deletions
@@ -7673,10 +7673,11 @@ class TestVarStd(TorchBaseTest):
     @pytest.mark.parametrize(
         "compute_unit, backend, frontend, torch_op, unbiased",
         itertools.product(
-            compute_units, backends, frontends, [torch.var, torch.std], [True, False]
+            compute_units, backends, frontends, ["var", "std"], [True, False]
         ),
     )
     def test_var_std_2_inputs(self, compute_unit, backend, frontend, torch_op, unbiased):
+        torch_op = getattr(torch, torch_op)
         model = ModuleWrapper(function=torch_op, kwargs={"unbiased": unbiased})
         x = torch.randn(1, 5, 10) * 3
         out = torch_op(x, unbiased=unbiased).unsqueeze(0)
@@ -7696,7 +7697,7 @@ def test_var_std_2_inputs(self, compute_unit, backend, frontend, torch_op, unbia
             compute_units,
             backends,
             frontends,
-            [torch.var, torch.std],
+            ["var", "std"],
             [True, False],
             [[0, 2], [1], [2]],
             [True, False],
@@ -7705,6 +7706,7 @@ def test_var_std_2_inputs(self, compute_unit, backend, frontend, torch_op, unbia
     def test_var_std_4_inputs(
         self, compute_unit, backend, frontend, torch_op, unbiased, dim, keepdim
     ):
+        torch_op = getattr(torch, torch_op)
         model = ModuleWrapper(
             function=torch_op,
             kwargs={"unbiased": unbiased, "dim": dim, "keepdim": keepdim},
@@ -7720,7 +7722,7 @@ def test_var_std_4_inputs(
             compute_units,
             backends,
             frontends,
-            [torch.var, torch.std],
+            ["var", "std"],
             [0, 1],
             [[0, 2], [1], [2]],
             [True, False],
@@ -7729,6 +7731,7 @@ def test_var_std_4_inputs(
     def test_var_std_with_correction(
         self, compute_unit, backend, frontend, torch_op, correction, dim, keepdim
     ):
+        torch_op = getattr(torch, torch_op)
         model = ModuleWrapper(
             function=torch_op,
             kwargs={"correction": correction, "dim": dim, "keepdim": keepdim},
@@ -9103,34 +9106,35 @@ def generate_tensor_rank_5(self, x):
             backends,
             frontends,
             [
-                torch.abs,
-                torch.acos,
-                torch.asin,
-                torch.atan,
-                torch.atanh,
-                torch.ceil,
-                torch.cos,
-                torch.cosh,
-                torch.exp,
-                torch.exp2,
-                torch.floor,
-                torch.log,
-                torch.log2,
-                torch.round,
-                torch.rsqrt,
-                torch.sign,
-                torch.sin,
-                torch.sinh,
-                torch.sqrt,
-                torch.square,
-                torch.tan,
-                torch.tanh,
+                "abs",
+                "acos",
+                "asin",
+                "atan",
+                "atanh",
+                "ceil",
+                "cos",
+                "cosh",
+                "exp",
+                "exp2",
+                "floor",
+                "log",
+                "log2",
+                "round",
+                "rsqrt",
+                "sign",
+                "sin",
+                "sinh",
+                "sqrt",
+                "square",
+                "tan",
+                "tanh",
             ],
         ),
     )
     def test_torch_rank0_tensor(self, compute_unit, backend, frontend, torch_op):
-        if frontend == TorchFrontend.EXECUTORCH and torch_op == torch.exp2:
+        if frontend == TorchFrontend.EXECUTORCH and torch_op == "exp2":
             pytest.skip("torch._ops.aten.exp2.default is not Aten Canonical")
+        torch_op = getattr(torch, torch_op)

         class Model(nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -12658,7 +12662,7 @@ class TestSTFT(TorchBaseTest):
             [16],  # n_fft
             [None, 4, 5],  # hop_length
             [None, 16, 9],  # win_length
-            [None, torch.hann_window],  # window
+            [None, "hann_window"],  # window
             [None, False, True],  # center
             ["constant", "reflect", "replicate"],  # pad mode
             [False, True],  # normalized
@@ -12668,6 +12672,8 @@ class TestSTFT(TorchBaseTest):
     def test_stft(self, compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
         if complex and onesided:
             pytest.skip("Onesided stft not possible for complex inputs")
+        if window is not None:
+            window = getattr(torch, window)

         class STFTModel(torch.nn.Module):
             def forward(self, x):
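With the parametrizations expressed as strings, the module should run under pytest-xdist in the usual way, e.g. `pytest -n auto coremltools/converters/mil/frontend/torch/test/test_torch_ops.py` (assuming pytest-xdist is installed; `-n auto` starts one worker per available CPU).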
