@@ -13707,24 +13707,70 @@ def test_cat_big(self, device):
13707
13707
result = torch.cat(concat_list)
13708
13708
self.assertEqual(result.size(0), SIZE1 + SIZE2)
13709
13709
13710
+ # NOTE [Linspace+Logspace precision override]
13711
+ # Our Linspace and logspace torch.half CUDA kernels are not very precise.
13712
+ # Since linspace/logspace are deterministic, we can compute an expected
13713
+ # amount of error (by testing without a precision override), adding a tiny
13714
+ # amount (EPS) to that, and using that value as the override.
13715
+ LINSPACE_LOGSPACE_EXTRA_EPS = 1e-5
13716
+
13710
13717
# Tests that compare a device's computation with the (gold-standard) CPU's.
13711
13718
class TestDevicePrecision(TestCase):
13712
- def test_linspace(self, device):
13713
- a = torch.linspace(0, 10, 10, device=device)
13714
- b = torch.linspace(0, 10, 10)
13719
+
13720
+ # The implementation of linspace+logspace goes through a different path
13721
+ # when the steps arg is equal to 0 or 1. For other values of `steps`
13722
+ # they call specialized linspace (or logspace) kernels.
13723
+ LINSPACE_LOGSPACE_SPECIAL_STEPS = [0, 1]
13724
+
13725
+ def _test_linspace(self, device, dtype, steps):
13726
+ a = torch.linspace(0, 10, steps=steps, dtype=dtype, device=device)
13727
+ b = torch.linspace(0, 10, steps=steps)
13715
13728
self.assertEqual(a, b)
13716
13729
13717
- @dtypes(torch.double)
13718
- def test_logspace(self, device, dtype):
13719
- a = torch.logspace(1, 10, 10, dtype=dtype, device=device)
13720
- b = torch.logspace(1, 10, 10, dtype=dtype, device='cpu')
13730
+ # See NOTE [Linspace+Logspace precision override]
13731
+ @precisionOverride({torch.half: 0.0039 + LINSPACE_LOGSPACE_EXTRA_EPS})
13732
+ @dtypesIfCUDA(torch.half, torch.float, torch.double)
13733
+ @dtypes(torch.float, torch.double)
13734
+ def test_linspace(self, device, dtype):
13735
+ self._test_linspace(device, dtype, steps=10)
13736
+
13737
+ @dtypesIfCUDA(torch.half, torch.float, torch.double)
13738
+ @dtypes(torch.float, torch.double)
13739
+ def test_linspace_special_steps(self, device, dtype):
13740
+ for steps in self.LINSPACE_LOGSPACE_SPECIAL_STEPS:
13741
+ self._test_linspace(device, dtype, steps=steps)
13742
+
13743
+ def _test_logspace(self, device, dtype, steps):
13744
+ a = torch.logspace(1, 1.1, steps=steps, dtype=dtype, device=device)
13745
+ b = torch.logspace(1, 1.1, steps=steps)
13721
13746
self.assertEqual(a, b)
13722
13747
13723
- # Check non-default base=2
13724
- a = torch.logspace(1, 10, 10, 2, dtype=dtype, device=device)
13725
- b = torch.logspace(1, 10, 10, 2, dtype=dtype, device='cpu' )
13748
+ def _test_logspace_base2(self, device, dtype, steps):
13749
+ a = torch.logspace(1, 1.1, steps=steps, base= 2, dtype=dtype, device=device)
13750
+ b = torch.logspace(1, 1.1, steps=steps, base=2 )
13726
13751
self.assertEqual(a, b)
13727
13752
13753
+ # See NOTE [Linspace+Logspace precision override]
13754
+ @precisionOverride({torch.half: 0.0157 + LINSPACE_LOGSPACE_EXTRA_EPS})
13755
+ @dtypesIfCUDA(torch.half, torch.float, torch.double)
13756
+ @dtypes(torch.float, torch.double)
13757
+ def test_logspace(self, device, dtype):
13758
+ self._test_logspace(device, dtype, steps=10)
13759
+
13760
+ # See NOTE [Linspace+Logspace precision override]
13761
+ @precisionOverride({torch.half: 0.00201 + LINSPACE_LOGSPACE_EXTRA_EPS})
13762
+ @dtypesIfCUDA(torch.half, torch.float, torch.double)
13763
+ @dtypes(torch.float, torch.double)
13764
+ def test_logspace_base2(self, device, dtype):
13765
+ self._test_logspace_base2(device, dtype, steps=10)
13766
+
13767
+ @dtypesIfCUDA(torch.half, torch.float, torch.double)
13768
+ @dtypes(torch.float, torch.double)
13769
+ def test_logspace_special_steps(self, device, dtype):
13770
+ for steps in self.LINSPACE_LOGSPACE_SPECIAL_STEPS:
13771
+ self._test_logspace(device, dtype, steps=steps)
13772
+ self._test_logspace_base2(device, dtype, steps=steps)
13773
+
13728
13774
# Note: ROCm fails when using float tensors
13729
13775
@dtypes(torch.double)
13730
13776
def test_polygamma(self, device, dtype):
0 commit comments