Skip to content

Commit e011d4a

Browse files
zou3519 authored and gchanan committed
Restore CUDA half linspace+logspace and add coverage tests (pytorch#31959)
This PR restores the implementation of CUDA half linspace+logspace. I added tests for the following: - linspace+logspace have the same support for integral types on CPU/CUDA - Precision tests for CUDA half, float, and double. The precision for CUDA half seems bad, but I checked the numbers against previous versions of pytorch. The output of CUDA Half linspace+logspace are exactly the same when compared with 1.2.0. Equivalent-ish PR on master: pytorch#31962
1 parent 8ada95e commit e011d4a

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

aten/src/ATen/native/cuda/RangeFactories.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Tensor& linspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step
5252
} else if (steps == 1) {
5353
r.fill_(start);
5454
} else {
55-
AT_DISPATCH_FLOATING_TYPES(r.scalar_type(), "linspace_cuda", [&]() {
55+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(r.scalar_type(), "linspace_cuda", [&]() {
5656
scalar_t scalar_start = start.to<scalar_t>();
5757
scalar_t scalar_end = end.to<scalar_t>();
5858
scalar_t step = (scalar_end - scalar_start) / static_cast<scalar_t>(steps - 1);
@@ -84,7 +84,7 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step
8484
} else if (steps == 1) {
8585
r.fill_(std::pow(base, start.to<double>()));
8686
} else {
87-
AT_DISPATCH_FLOATING_TYPES(r.scalar_type(), "logspace_cuda", [&]() {
87+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(r.scalar_type(), "logspace_cuda", [&]() {
8888
scalar_t scalar_base = static_cast<scalar_t>(base);
8989
scalar_t scalar_start = start.to<scalar_t>();
9090
scalar_t scalar_end = end.to<scalar_t>();

test/test_torch.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13707,24 +13707,70 @@ def test_cat_big(self, device):
1370713707
result = torch.cat(concat_list)
1370813708
self.assertEqual(result.size(0), SIZE1 + SIZE2)
1370913709

13710+
# NOTE [Linspace+Logspace precision override]
13711+
# Our Linspace and logspace torch.half CUDA kernels are not very precise.
13712+
# Since linspace/logspace are deterministic, we can compute an expected
13713+
# amount of error (by testing without a precision override), adding a tiny
13714+
# amount (EPS) to that, and using that value as the override.
13715+
LINSPACE_LOGSPACE_EXTRA_EPS = 1e-5
13716+
1371013717
# Tests that compare a device's computation with the (gold-standard) CPU's.
1371113718
class TestDevicePrecision(TestCase):
13712-
def test_linspace(self, device):
13713-
a = torch.linspace(0, 10, 10, device=device)
13714-
b = torch.linspace(0, 10, 10)
13719+
13720+
# The implementation of linspace+logspace goes through a different path
13721+
# when the steps arg is equal to 0 or 1. For other values of `steps`
13722+
# they call specialized linspace (or logspace) kernels.
13723+
LINSPACE_LOGSPACE_SPECIAL_STEPS = [0, 1]
13724+
13725+
def _test_linspace(self, device, dtype, steps):
13726+
a = torch.linspace(0, 10, steps=steps, dtype=dtype, device=device)
13727+
b = torch.linspace(0, 10, steps=steps)
1371513728
self.assertEqual(a, b)
1371613729

13717-
@dtypes(torch.double)
13718-
def test_logspace(self, device, dtype):
13719-
a = torch.logspace(1, 10, 10, dtype=dtype, device=device)
13720-
b = torch.logspace(1, 10, 10, dtype=dtype, device='cpu')
13730+
# See NOTE [Linspace+Logspace precision override]
13731+
@precisionOverride({torch.half: 0.0039 + LINSPACE_LOGSPACE_EXTRA_EPS})
13732+
@dtypesIfCUDA(torch.half, torch.float, torch.double)
13733+
@dtypes(torch.float, torch.double)
13734+
def test_linspace(self, device, dtype):
13735+
self._test_linspace(device, dtype, steps=10)
13736+
13737+
@dtypesIfCUDA(torch.half, torch.float, torch.double)
13738+
@dtypes(torch.float, torch.double)
13739+
def test_linspace_special_steps(self, device, dtype):
13740+
for steps in self.LINSPACE_LOGSPACE_SPECIAL_STEPS:
13741+
self._test_linspace(device, dtype, steps=steps)
13742+
13743+
def _test_logspace(self, device, dtype, steps):
13744+
a = torch.logspace(1, 1.1, steps=steps, dtype=dtype, device=device)
13745+
b = torch.logspace(1, 1.1, steps=steps)
1372113746
self.assertEqual(a, b)
1372213747

13723-
# Check non-default base=2
13724-
a = torch.logspace(1, 10, 10, 2, dtype=dtype, device=device)
13725-
b = torch.logspace(1, 10, 10, 2, dtype=dtype, device='cpu')
13748+
def _test_logspace_base2(self, device, dtype, steps):
13749+
a = torch.logspace(1, 1.1, steps=steps, base=2, dtype=dtype, device=device)
13750+
b = torch.logspace(1, 1.1, steps=steps, base=2)
1372613751
self.assertEqual(a, b)
1372713752

13753+
# See NOTE [Linspace+Logspace precision override]
13754+
@precisionOverride({torch.half: 0.0157 + LINSPACE_LOGSPACE_EXTRA_EPS})
13755+
@dtypesIfCUDA(torch.half, torch.float, torch.double)
13756+
@dtypes(torch.float, torch.double)
13757+
def test_logspace(self, device, dtype):
13758+
self._test_logspace(device, dtype, steps=10)
13759+
13760+
# See NOTE [Linspace+Logspace precision override]
13761+
@precisionOverride({torch.half: 0.00201 + LINSPACE_LOGSPACE_EXTRA_EPS})
13762+
@dtypesIfCUDA(torch.half, torch.float, torch.double)
13763+
@dtypes(torch.float, torch.double)
13764+
def test_logspace_base2(self, device, dtype):
13765+
self._test_logspace_base2(device, dtype, steps=10)
13766+
13767+
@dtypesIfCUDA(torch.half, torch.float, torch.double)
13768+
@dtypes(torch.float, torch.double)
13769+
def test_logspace_special_steps(self, device, dtype):
13770+
for steps in self.LINSPACE_LOGSPACE_SPECIAL_STEPS:
13771+
self._test_logspace(device, dtype, steps=steps)
13772+
self._test_logspace_base2(device, dtype, steps=steps)
13773+
1372813774
# Note: ROCm fails when using float tensors
1372913775
@dtypes(torch.double)
1373013776
def test_polygamma(self, device, dtype):

0 commit comments

Comments (0)