
Commit f629fbe

IvanYashchuk authored and facebook-github-bot committed
Added torch.linalg.tensorsolve (pytorch#46142)
Summary: This PR adds `torch.linalg.tensorsolve` function that matches `numpy.linalg.tensorsolve`. Ref pytorch#42666.

Pull Request resolved: pytorch#46142
Reviewed By: izdeby
Differential Revision: D24539400
Pulled By: mruberry
fbshipit-source-id: 6e38364fe0bc511e739036deb274d9307df119b2
1 parent 13b4127 commit f629fbe

File tree

9 files changed: +264 −6 lines changed


aten/src/ATen/native/BatchLinearAlgebra.cpp

+2-1
@@ -331,6 +331,7 @@ static void apply_solve(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
   auto batch_size = batchCount(A);
   auto n = A.size(-2);
   auto nrhs = b.size(-1);
+  auto lda = std::max(int64_t{1}, n);

   auto ipiv = at::empty({n}, b.options().dtype(kInt));
   auto ipiv_data = ipiv.data_ptr<int>();
@@ -339,7 +340,7 @@ static void apply_solve(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
   for (int64_t i = 0; i < batch_size; i++) {
     scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
     scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
-    lapackSolve<scalar_t>(n, nrhs, A_working_ptr, n, ipiv_data, b_working_ptr, n, &info);
+    lapackSolve<scalar_t>(n, nrhs, A_working_ptr, lda, ipiv_data, b_working_ptr, lda, &info);
     infos[i] = info;
     if (info != 0) {
       return;
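The `lda` clamp above matters for zero-sized inputs: LAPACK requires the leading dimension to be at least 1 even when `n == 0`, so passing `max(1, n)` lets the solver accept empty matrices (which `tensorsolve` relies on, see `test_tensorsolve_empty` below). A minimal sketch of the case this unlocks, assuming a CPU build with LAPACK; the exact degenerate behaviour is an assumption, not part of this diff:

import torch

# Hypothetical illustration: an "empty" 0x0 system. With lda clamped to
# max(1, n) the underlying gesv call should degenerate gracefully instead
# of rejecting lda == 0.
A = torch.empty(0, 0)       # 0x0 coefficient matrix
b = torch.empty(0, 1)       # matching empty right-hand side
x, lu = torch.solve(b, A)   # torch.solve is the op patched in this file
print(x.shape)              # torch.Size([0, 1])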

aten/src/ATen/native/LinearAlgebra.cpp

+49
@@ -1602,6 +1602,55 @@ Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, opt
   return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype);
 }

+Tensor linalg_tensorsolve(const Tensor& self, const Tensor& other, optional<IntArrayRef> dims) {
+  /*
+  The idea is to reduce the problem to 2D matrix solve.
+  Step 1. (optional) `self` is permuted with `dims` such that dimensions from `dims` are moved to the right.
+  For example, if we have 4D input with the shape (1, 2, 3, 4) and dims=(0, 2),
+  then the result of permutation would have the shape (2, 4, 1, 3).
+  Step 2. reshape `self` to 2D matrix.
+  Step 3. solve the matrix equation self.to_2D() @ result = other.to_1D()
+  Step 4. reshape the result.
+  */
+  int64_t ndim = self.dim();
+  Tensor self_ = self;
+
+  // move dimensions of `self_` from `dims` to the end
+  if (dims.has_value()) {
+    DimVector dest_axes(dims.value().size());
+    std::iota(dest_axes.begin(), dest_axes.end(), ndim - dest_axes.size());
+    self_ = at::movedim(self_, dims.value(), dest_axes);
+  }
+
+  // result_shape is self_.sizes[-(an-other.dim):]
+  std::vector<int64_t> result_shape = self_.sizes().slice(other.dim(), ndim - other.dim()).vec();
+
+  int64_t result_product = std::accumulate(result_shape.begin(), result_shape.end(), int64_t{1}, std::multiplies<int64_t>());
+  int64_t other_product = std::accumulate(other.sizes().begin(), other.sizes().end(), int64_t{1}, std::multiplies<int64_t>());
+
+  // Check whether the self tensor can be reshaped to the 2D square matrix
+  TORCH_CHECK(result_product == other_product,
+    "Expected self to satisfy the requirement prod(self.shape[other.ndim:]) == prod(self.shape[:other.ndim]), but got ",
+    result_product, " != ", other_product);
+
+  self_ = self_.reshape({result_product, result_product});
+
+  // 0th output of at::solve is the solution
+  // normally `other` would be flattened by at::solve expects 2D input
+  Tensor result = std::get<0>(at::solve(other.reshape({other.numel(), 1}), self_));
+  return result.reshape(result_shape);
+}
+
+Tensor& linalg_tensorsolve_out(Tensor& result, const Tensor& self, const Tensor& other, optional<IntArrayRef> dims) {
+  TORCH_CHECK(result.scalar_type() == self.scalar_type(),
+    "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type());
+
+  Tensor result_tmp = at::linalg_tensorsolve(self, other, dims);
+  at::native::resize_output(result, result_tmp.sizes());
+  result.copy_(result_tmp);
+  return result;
+}
+
 static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
   if (i == j)
     return matrices[i];
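The comment block in `linalg_tensorsolve` describes a reduction to an ordinary 2-D solve. Roughly the same steps can be written in Python; this is an illustrative sketch only, not the shipped implementation, and `torch.linalg.solve` here stands in for the `at::solve` call the PR actually uses:

import torch
from math import prod

def tensorsolve_sketch(a, b, dims=None):
    # Rough Python equivalent of the four steps in the C++ comment above.
    ndim = a.dim()
    if dims is not None:
        # Step 1: move the dimensions listed in `dims` to the end of `a`.
        dest = tuple(range(ndim - len(dims), ndim))
        a = torch.movedim(a, dims, dest)
    # Step 2: the trailing dims not matched by `b` make up the result shape.
    result_shape = a.shape[b.dim():]
    n = prod(result_shape)
    # Step 3: reshape `a` to an (n, n) matrix and solve a2d @ x = b.flatten().
    a2d = a.reshape(n, n)
    x = torch.linalg.solve(a2d, b.reshape(n))  # stand-in for at::solve
    # Step 4: give the flat solution its tensor shape back.
    return x.reshape(result_shape)

a = torch.randn(2, 3, 6)
b = torch.randn(2, 3)
x = tensorsolve_sketch(a, b)
print(torch.allclose(torch.tensordot(a, x, dims=x.dim()), b, atol=1e-6))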

aten/src/ATen/native/cuda/BatchLinearAlgebra.cu

+5-4
@@ -840,12 +840,13 @@ AT_ERROR("solve: MAGMA library not found in "
   auto b_data = b.data_ptr<scalar_t>();
   magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
   magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
+  magma_int_t lda = std::max(magma_int_t{1}, n);

   if (b.dim() == 2) {
     auto ipiv = at::empty({n}, at::kInt);
     magma_int_t info = 0;
-    magmaSolve<scalar_t>(n, nrhs, A_data, n, ipiv.data_ptr<magma_int_t>(),
-                         b_data, n, &info);
+    magmaSolve<scalar_t>(n, nrhs, A_data, lda, ipiv.data_ptr<magma_int_t>(),
+                         b_data, lda, &info);
     infos[0] = info;
   } else {
     auto A_mat_stride = matrixStride(A);
@@ -885,15 +886,15 @@ AT_ERROR("solve: MAGMA library not found in "
       magma_int_t* info_array_cur = &info_array[mini_idx];

       magmaSolveBatched<scalar_t>(
-          n, nrhs, A_array_cur, n, ipiv_array_cur, b_array_cur, n,
+          n, nrhs, A_array_cur, lda, ipiv_array_cur, b_array_cur, lda,
           info_array_cur, batch_limit, magma_queue);
     }

     // Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit
     // which concisely is equal to batch_size % batch_limit
     if (batch_size % batch_limit != 0) {
       magmaSolveBatched<scalar_t>(
-          n, nrhs, &A_array[mini_idx], n, &ipiv_array[mini_idx], &b_array[mini_idx], n,
+          n, nrhs, &A_array[mini_idx], lda, &ipiv_array[mini_idx], &b_array[mini_idx], lda,
           &info_array[mini_idx], batch_size % batch_limit, magma_queue);
     }
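This is the same `lda >= 1` clamp applied to the MAGMA path, covering both the single-matrix and the batched calls. An illustrative Python snippet of the batched branch this touches (assumes a CUDA build with MAGMA available; skipped otherwise):

import torch

# Illustrative only: a 3-D (batched) input routes torch.solve through the
# magmaSolveBatched call patched above.
if torch.cuda.is_available():
    A = torch.randn(4, 3, 3, device="cuda")  # batch of 4 coefficient matrices
    B = torch.randn(4, 3, 2, device="cuda")  # matching right-hand sides
    X, LU = torch.solve(B, A)                # solves A @ X = B per batch element
    print(torch.allclose(torch.matmul(A, X), B, atol=1e-4))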

aten/src/ATen/native/native_functions.yaml

+12
@@ -8875,6 +8875,18 @@
   python_module: linalg
   variants: function

+- func: linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor
+  python_module: linalg
+  variants: function
+  dispatch:
+    Math: linalg_tensorsolve
+
+- func: linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+  dispatch:
+    Math: linalg_tensorsolve_out
+
 ## Functions that are only for testing
 # It is undocumented and should not be used outside of tests.
 - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
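`python_module: linalg` places both schemas under `torch.linalg`, and the `Math` dispatch key marks them as composite ops built from existing operators. A short sketch of the resulting Python surface, adapted from the docstring example added later in this PR:

import torch

# Once the schemas above are registered, the op is callable as
# torch.linalg.tensorsolve, including the out= overload from the second entry.
a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
b = torch.randn(2 * 3, 4)
out = torch.empty(2, 3, 4)
torch.linalg.tensorsolve(a, b, out=out)
print(out.shape)  # torch.Size([2, 3, 4])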

docs/source/linalg.rst

+1
@@ -14,3 +14,4 @@ Functions

 .. autofunction:: det
 .. autofunction:: norm
+.. autofunction:: tensorsolve

test/test_linalg.py

+123-1
@@ -1,13 +1,15 @@
 import torch
 import unittest
 import itertools
+import warnings
 from math import inf, nan, isnan
 from random import randrange

 from torch.testing._internal.common_utils import \
     (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor)
 from torch.testing._internal.common_device_type import \
-    (instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride)
+    (instantiate_device_type_tests, dtypes, dtypesIfCUDA,
+     onlyCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride)
 from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args
 from torch.autograd import gradcheck

@@ -914,6 +916,126 @@ def test_nuclear_norm_exceptions_old(self, device):
         self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0))
         self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2))

+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
+    @dtypesIfCUDA(torch.float, torch.double)
+    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
+    def test_tensorsolve(self, device, dtype):
+        def run_test(a_shape, dims):
+            a = torch.randn(a_shape, dtype=dtype, device=device)
+            b = torch.randn(a_shape[:2], dtype=dtype, device=device)
+            result = torch.linalg.tensorsolve(a, b, dims=dims)
+            expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
+            self.assertEqual(result, expected)
+
+            # check the out= variant
+            out = torch.empty_like(result)
+            ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, result)
+
+        a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
+        dims = [None, (0, 2)]
+        for a_shape, d in itertools.product(a_shapes, dims):
+            run_test(a_shape, d)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
+    @dtypesIfCUDA(torch.float, torch.double)
+    def test_tensorsolve_empty(self, device, dtype):
+        # Check for empty inputs. NumPy does not work for these cases.
+        a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
+        b = torch.empty(a.shape[:2], dtype=dtype, device=device)
+        x = torch.linalg.tensorsolve(a, b)
+        self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b)
+
+    # TODO: once "solve_cuda" supports complex dtypes, they shall be added to above tests
+    @unittest.expectedFailure
+    @onlyCUDA
+    @skipCUDAIfNoMagma
+    @dtypes(torch.cfloat, torch.cdouble)
+    def test_tensorsolve_xfailed(self, device, dtype):
+        a_shape = (2, 3, 6)
+        a = torch.randn(a_shape, dtype=dtype, device=device)
+        b = torch.randn(a_shape[:2], dtype=dtype, device=device)
+        result = torch.linalg.tensorsolve(a, b)
+        expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy())
+        self.assertEqual(result, expected)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
+    @dtypesIfCUDA(torch.float, torch.double)
+    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
+    def test_tensorsolve_non_contiguous(self, device, dtype):
+        def run_test_permuted(a_shape, dims):
+            # check for permuted / transposed inputs
+            a = torch.randn(a_shape, dtype=dtype, device=device)
+            a = a.movedim((0, 2), (-2, -1))
+            self.assertFalse(a.is_contiguous())
+            b = torch.randn(a.shape[:2], dtype=dtype, device=device)
+            b = b.t()
+            self.assertFalse(b.is_contiguous())
+            result = torch.linalg.tensorsolve(a, b, dims=dims)
+            expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
+            self.assertEqual(result, expected)
+
+        def run_test_skipped_elements(a_shape, dims):
+            # check for inputs with skipped elements
+            a = torch.randn(a_shape, dtype=dtype, device=device)
+            a = a[::2]
+            self.assertFalse(a.is_contiguous())
+            b = torch.randn(a_shape[:2], dtype=dtype, device=device)
+            b = b[::2]
+            self.assertFalse(b.is_contiguous())
+            result = torch.linalg.tensorsolve(a, b, dims=dims)
+            expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
+            self.assertEqual(result, expected)
+
+            # check non-contiguous out
+            out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2]
+            self.assertFalse(out.is_contiguous())
+            ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, result)
+
+        a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
+        dims = [None, (0, 2)]
+        for a_shape, d in itertools.product(a_shapes, dims):
+            run_test_permuted(a_shape, d)
+
+        a_shapes = [(4, 3, 6), (6, 4, 4, 3)]
+        dims = [None, (0, 2)]
+        for a_shape, d in itertools.product(a_shapes, dims):
+            run_test_skipped_elements(a_shape, d)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float32)
+    def test_tensorsolve_errors_and_warnings(self, device, dtype):
+        # tensorsolve expects the input that can be reshaped to a square matrix
+        a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
+        b = torch.randn(8, 4)
+        self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape))
+        with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'):
+            torch.linalg.tensorsolve(a, b)
+
+        # if non-empty out tensor with wrong shape is passed a warning is given
+        out = torch.empty_like(a)
+        b = torch.randn(6, 4)
+        with warnings.catch_warnings(record=True) as w:
+            # Trigger warning
+            torch.linalg.tensorsolve(a, b, out=out)
+            # Check warning occurs
+            self.assertEqual(len(w), 1)
+            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
+
+        # dtypes should match
+        out = torch.empty_like(a).to(torch.int)
+        with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"):
+            torch.linalg.tensorsolve(a, b, out=out)

 instantiate_device_type_tests(TestLinalg, globals())

torch/csrc/api/include/torch/linalg.h

+26
@@ -28,6 +28,14 @@ inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, opt
   return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
 }

+inline Tensor tensorsolve(const Tensor& self, const Tensor& other, optional<IntArrayRef> dims) {
+  return torch::linalg_tensorsolve(self, other, dims);
+}
+
+inline Tensor& tensorsolve_out(Tensor& result, const Tensor& self, const Tensor& other, optional<IntArrayRef> dims) {
+  return torch::linalg_tensorsolve_out(result, self, other, dims);
+}
+
 } // namespace detail
 #endif /* DOXYGEN_SHOULD_SKIP_THIS */

@@ -53,4 +61,22 @@ inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string o
   return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
 }

+/// Computes a tensor `x` such that `tensordot(input, x, dims=x.dim()) = other`.
+///
+/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.tensorsolve
+///
+/// Example:
+/// ```
+/// auto a = torch::eye(2*3*4).reshape({2*3, 4, 2, 3, 4});
+/// auto b = torch::randn(2*3, 4);
+/// auto x = torch::linalg::tensorsolve(a, b);
+/// ```
+inline Tensor tensorsolve(const Tensor& input, const Tensor& other, optional<IntArrayRef> dims) {
+  return detail::tensorsolve(input, other, dims);
+}
+
+inline Tensor& tensorsolve_out(Tensor& result, const Tensor& input, const Tensor& other, optional<IntArrayRef> dims) {
+  return detail::tensorsolve_out(result, input, other, dims);
+}
+
 }} // torch::linalg

torch/linalg/__init__.py

+45
@@ -139,3 +139,48 @@
 >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
 (tensor(3.7417), tensor(11.2250))
 """)
+
+tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r"""
+linalg.tensorsolve(input, other, dims=None, *, out=None) -> Tensor
+
+Computes a tensor ``x`` such that ``tensordot(input, x, dims=x.ndim) = other``.
+The resulting tensor ``x`` has the same shape as ``input[other.ndim:]``.
+
+Supports real-valued and, only on the CPU, complex-valued inputs.
+
+.. note:: If :attr:`input` does not satisfy the requirement
+          ``prod(input.shape[other.ndim:]) == prod(input.shape[:other.ndim])``
+          after (optionally) moving the dimensions using :attr:`dims`, then a RuntimeError will be thrown.
+
+Args:
+    input (Tensor): "left-hand-side" tensor, it must satisfy the requirement
+                    ``prod(input.shape[other.ndim:]) == prod(input.shape[:other.ndim])``.
+    other (Tensor): "right-hand-side" tensor of shape ``input.shape[other.ndim]``.
+    dims (Tuple[int]): dimensions of :attr:`input` to be moved before the computation.
+                       Equivalent to calling ``input = movedim(input, dims, range(len(dims) - input.ndim, 0))``.
+                       If None (default), no dimensions are moved.
+
+Keyword args:
+    out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None``
+
+Examples::
+
+    >>> a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
+    >>> b = torch.randn(2 * 3, 4)
+    >>> x = torch.linalg.tensorsolve(a, b)
+    >>> x.shape
+    torch.Size([2, 3, 4])
+    >>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b)
+    True
+
+    >>> a = torch.randn(6, 4, 4, 3, 2)
+    >>> b = torch.randn(4, 3, 2)
+    >>> x = torch.linalg.tensorsolve(a, b, dims=(0, 2))
+    >>> x.shape
+    torch.Size([6, 4])
+    >>> a = a.permute(1, 3, 4, 0, 2)
+    >>> a.shape[b.ndim:]
+    torch.Size([6, 4])
+    >>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b, atol=1e-6)
+    True
+""")

torch/overrides.py

+1
@@ -738,6 +738,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.tan: lambda input, out=None: -1,
     torch.tanh: lambda input, out=None: -1,
     torch.tensordot: lambda a, b, dims=2: -1,
+    torch.linalg.tensorsolve: lambda a, b, dims=None: -1,
     torch.tensor_split: lambda input, indices_or_sections, dim=0: -1,
     torch.threshold: lambda input, threshold, value, inplace=False: -1,
     torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1,
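The one-line addition to `get_testing_overrides` gives the override tests a dummy signature for the new function, which also means `torch.linalg.tensorsolve` participates in the `__torch_function__` protocol. A minimal sketch; the `LoggingTensor` subclass is hypothetical and purely illustrative:

import torch

class LoggingTensor(torch.Tensor):
    # Toy subclass (illustrative only) that logs calls routed through __torch_function__.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.linalg.tensorsolve:
            print("intercepted torch.linalg.tensorsolve")
        return super().__torch_function__(func, types, args, kwargs)

a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)).as_subclass(LoggingTensor)
b = torch.randn(2 * 3, 4)
x = torch.linalg.tensorsolve(a, b)  # prints the message, then returns the solution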
