@@ -1,13 +1,15 @@
 import torch
 import unittest
 import itertools
+import warnings
 from math import inf, nan, isnan
 from random import randrange

 from torch.testing._internal.common_utils import \
     (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor)
 from torch.testing._internal.common_device_type import \
-    (instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride)
+    (instantiate_device_type_tests, dtypes, dtypesIfCUDA,
+     onlyCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride)
 from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args
 from torch.autograd import gradcheck

@@ -914,6 +916,126 @@ def test_nuclear_norm_exceptions_old(self, device):
         self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0))
         self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2))

+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
+    @dtypesIfCUDA(torch.float, torch.double)
+    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
+    def test_tensorsolve(self, device, dtype):
+        def run_test(a_shape, dims):
+            a = torch.randn(a_shape, dtype=dtype, device=device)
+            b = torch.randn(a_shape[:2], dtype=dtype, device=device)
+            result = torch.linalg.tensorsolve(a, b, dims=dims)
+            expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
+            self.assertEqual(result, expected)
+
+            # check the out= variant
+            out = torch.empty_like(result)
+            ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, result)
+
+        a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
+        dims = [None, (0, 2)]
+        for a_shape, d in itertools.product(a_shapes, dims):
+            run_test(a_shape, d)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
+    @dtypesIfCUDA(torch.float, torch.double)
+    def test_tensorsolve_empty(self, device, dtype):
+        # Check for empty inputs. NumPy does not work for these cases.
+        a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
+        b = torch.empty(a.shape[:2], dtype=dtype, device=device)
+        x = torch.linalg.tensorsolve(a, b)
+        self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b)
+
+    # TODO: once "solve_cuda" supports complex dtypes, add them to the tests above
+    @unittest.expectedFailure
+    @onlyCUDA
+    @skipCUDAIfNoMagma
+    @dtypes(torch.cfloat, torch.cdouble)
+    def test_tensorsolve_xfailed(self, device, dtype):
+        a_shape = (2, 3, 6)
+        a = torch.randn(a_shape, dtype=dtype, device=device)
+        b = torch.randn(a_shape[:2], dtype=dtype, device=device)
+        result = torch.linalg.tensorsolve(a, b)
+        expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy())
+        self.assertEqual(result, expected)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
+    @dtypesIfCUDA(torch.float, torch.double)
+    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
+    def test_tensorsolve_non_contiguous(self, device, dtype):
+        def run_test_permuted(a_shape, dims):
+            # check for permuted / transposed inputs
+            a = torch.randn(a_shape, dtype=dtype, device=device)
+            a = a.movedim((0, 2), (-2, -1))
+            self.assertFalse(a.is_contiguous())
+            b = torch.randn(a.shape[:2], dtype=dtype, device=device)
+            b = b.t()
+            self.assertFalse(b.is_contiguous())
+            result = torch.linalg.tensorsolve(a, b, dims=dims)
+            expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
+            self.assertEqual(result, expected)
+
+        def run_test_skipped_elements(a_shape, dims):
+            # check for inputs with skipped elements
+            a = torch.randn(a_shape, dtype=dtype, device=device)
+            a = a[::2]
+            self.assertFalse(a.is_contiguous())
+            b = torch.randn(a_shape[:2], dtype=dtype, device=device)
+            b = b[::2]
+            self.assertFalse(b.is_contiguous())
+            result = torch.linalg.tensorsolve(a, b, dims=dims)
+            expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
+            self.assertEqual(result, expected)
+
+            # check non-contiguous out
+            out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2]
+            self.assertFalse(out.is_contiguous())
+            ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, result)
+
+        a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
+        dims = [None, (0, 2)]
+        for a_shape, d in itertools.product(a_shapes, dims):
+            run_test_permuted(a_shape, d)
+
+        a_shapes = [(4, 3, 6), (6, 4, 4, 3)]
+        dims = [None, (0, 2)]
+        for a_shape, d in itertools.product(a_shapes, dims):
+            run_test_skipped_elements(a_shape, d)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float32)
+    def test_tensorsolve_errors_and_warnings(self, device, dtype):
+        # tensorsolve expects an input that can be reshaped into a square matrix
+        a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
+        b = torch.randn(8, 4)
+        self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape))
+        with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'):
+            torch.linalg.tensorsolve(a, b)
+
+        # if a non-empty out tensor with the wrong shape is passed, a warning is given
+        out = torch.empty_like(a)
+        b = torch.randn(6, 4)
+        with warnings.catch_warnings(record=True) as w:
+            # Trigger warning
+            torch.linalg.tensorsolve(a, b, out=out)
+            # Check warning occurs
+            self.assertEqual(len(w), 1)
+            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
+
+        # dtypes should match
+        out = torch.empty_like(a).to(torch.int)
+        with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"):
+            torch.linalg.tensorsolve(a, b, out=out)

 instantiate_device_type_tests(TestLinalg, globals())

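Note: as a minimal standalone sketch (not part of the diff), the identity these tests rely on is that `torch.linalg.tensorsolve(a, b)` returns an `x` such that `torch.tensordot(a, x, dims=x.ndim)` reproduces `b`; this is exactly the check made in `test_tensorsolve_empty` above. Shapes below are illustrative, mirroring the `(2, 3, 6)` case used in the tests.

```python
import torch

# The product of the trailing dims of `a` (6) must equal the product of b's shape (2 * 3).
a = torch.randn(2, 3, 6, dtype=torch.float64)
b = torch.randn(2, 3, dtype=torch.float64)

x = torch.linalg.tensorsolve(a, b)  # x has shape (6,)

# Defining property: contracting `a` with the solution over x's dims recovers `b`.
assert torch.allclose(torch.tensordot(a, x, dims=x.ndim), b, atol=1e-8)
```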