Skip to content

Commit 2b16299

Browse files
Remove torch.jit.fuser("fuser2") in test (#7069)
* [WIP] Remove torch.jit.fuser("fuser2") in test Internally we're considering removing support for fuser2, so we need to remove this special case from the test. * completely remove special-casing
1 parent 35f68a0 commit 2b16299

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

test/test_ops.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -1555,13 +1555,7 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
15551555
torch.random.manual_seed(seed)
15561556
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
15571557
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
1558-
if device == "cpu":
1559-
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
1560-
else:
1561-
with torch.jit.fuser("fuser2"):
1562-
# Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476
1563-
# We may remove this condition once the bug is resolved
1564-
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
1558+
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
15651559

15661560
tol = 1e-3 if dtype is torch.half else 1e-5
15671561
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)

0 commit comments

Comments
 (0)