
Commit 9c68f81

Add solver for triangular systems (#1504)
* solve_triangular from previous PR #1236
* added test for type and value errors

Co-authored-by: Claudia Comito <[email protected]>
1 parent 9296f05 commit 9c68f81
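
For orientation, the new routine can be exercised roughly as follows. This is an illustrative sketch based on the diff below, not code from the commit itself; the problem size, the split choice, and the tolerance are arbitrary. Run under MPI (e.g. `mpirun -n 4 python example.py`), every rank holds a slice of `A`, `b`, and the result:

import heat as ht

n = 100
# well-conditioned upper triangular matrix, distributed along the rows (split=0)
A = ht.triu(ht.random.rand(n, n, split=0) + 1e2)
b = ht.ones((n, 1), split=0)  # right-hand side provided as an n x 1 matrix

x = ht.linalg.solve_triangular(A, b)  # distributed block-wise backward substitution
print(ht.allclose(A @ x, b, atol=1e-4))  # residual check; should typically print True in float32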

2 files changed: +323 -1 lines changed


Diff for: heat/core/linalg/solver.py (+192 -1)
@@ -10,7 +10,7 @@
 
 import torch
 
-__all__ = ["cg", "lanczos"]
+__all__ = ["cg", "lanczos", "solve_triangular"]
 
 
 def cg(A: DNDarray, b: DNDarray, x0: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
@@ -270,3 +270,194 @@ def lanczos(
     V.resplit_(axis=None)
 
     return V, T
+
+
+def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
+    """
+    This function provides a solver for (possibly batched) upper triangular systems of linear equations: it returns `x` in `Ax = b`, where `A` is a (possibly batched) upper triangular matrix and
+    `b` a (possibly batched) vector or matrix of suitable shape, both provided as input to the function.
+    The implementation builds on the corresponding solver in PyTorch and provides a memory-distributed, MPI-parallel, block-wise version thereof.
+    Parameters
+    ----------
+    A : DNDarray
+        An upper triangular invertible square (n x n) matrix or a batch thereof, i.e. a ``DNDarray`` of shape `(..., n, n)`.
+    b : DNDarray
+        A (possibly batched) n x k matrix, i.e. a ``DNDarray`` of shape `(..., n, k)`, where the batch dimensions denoted by `...` need to coincide with those of A.
+        (Batched) vectors have to be provided as ... x n x 1 matrices, and the split dimension of b must be the second-to-last dimension if not None.
+    Note
+    ---------
+    We do not check whether A is indeed upper triangular, since such a check might be computationally expensive.
+    If you require such a check, please open an issue on our GitHub page and request this feature.
+    """
+    if not isinstance(A, DNDarray) or not isinstance(b, DNDarray):
+        raise TypeError(f"Arguments need to be of type DNDarray, got {type(A)}, {type(b)}.")
+    if not A.ndim >= 2:
+        raise ValueError("A needs to be a (batched) matrix.")
+    if not b.ndim == A.ndim:
+        raise ValueError("b needs to have the same number of (batch) dimensions as A.")
+    if not A.shape[-2] == A.shape[-1]:
+        raise ValueError("A needs to be a (batched) square matrix.")
+
+    batch_dim = A.ndim - 2
+    batch_shape = A.shape[:batch_dim]
+
+    if not A.shape[:batch_dim] == b.shape[:batch_dim]:
+        raise ValueError("Batch dimensions of A and b must be of the same shape.")
+    if b.split == batch_dim + 1:
+        raise ValueError("split=1 is not allowed for the right hand side.")
+    if not b.shape[batch_dim] == A.shape[-1]:
+        raise ValueError("Dimension mismatch of A and b.")
+
+    if (
+        A.split is not None and A.split < batch_dim or b.split is not None and b.split < batch_dim
+    ):  # batch split
+        if A.split != b.split:
+            raise ValueError(
+                "If a split dimension is a batch dimension, A and b must have the same split dimension. A possible solution would be a resplit of A or b to the same split dimension."
+            )
+    else:
+        if (
+            A.split is not None and b.split is not None
+        ):  # both la dimensions split --> b.split = batch_dim
+            # TODO remove?
+            if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]):
+                raise RuntimeError(
+                    "The process-local arrays of A and b have different sizes along the split axis. This is most likely due to one of the DNDarrays being in an unbalanced state. \n Consider using `A.is_balanced(force_check=True)` and `b.is_balanced(force_check=True)` to check if A and b are balanced; \n then call `A.balance_()` and/or `b.balance_()` in order to achieve equal local shapes along the split axis before applying `solve_triangular`."
                )
+
+    comm = A.comm
+    dev = A.device
+    tdev = dev.torch_device
+
+    nprocs = comm.Get_size()
+
+    if A.split is None:  # A not split
+        if b.split is None:
+            x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)
+
+            return factories.array(x, dtype=b.dtype, device=dev, comm=comm)
+        else:  # A not split, b.split == -2
+            b_lshapes_cum = torch.hstack(
+                [
+                    torch.zeros(1, dtype=torch.int32, device=tdev),
+                    torch.cumsum(b.lshape_map[:, -2], 0),
+                ]
+            )
+
+            btilde_loc = b.larray.clone()
+            A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]]
+
+            x = factories.zeros_like(b, device=dev, comm=comm)
+
+            for i in range(nprocs - 1, 0, -1):
+                count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
+                displ = b_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
+                count[i:] = 0  # nothing to send, as there are only zero rows
+                displ[i:] = 0
+
+                res_send = torch.empty(0)
+                res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)
+
+                if comm.rank == i:
+                    x.larray = torch.linalg.solve_triangular(
+                        A_loc[..., b_lshapes_cum[i] : b_lshapes_cum[i + 1], :],
+                        btilde_loc,
+                        upper=True,
+                    )
+                    res_send = A_loc @ x.larray
+
+                comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)
+
+                if comm.rank < i:
+                    btilde_loc -= res_recv
+
+            if comm.rank == 0:
+                x.larray = torch.linalg.solve_triangular(
+                    A_loc[..., : b_lshapes_cum[1], :], btilde_loc, upper=True
+                )
+
+            return x
+
+    if A.split < batch_dim:  # batch split
+        x = factories.zeros_like(b, device=dev, comm=comm, split=A.split)
+        x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)
+
+        return x
+
+    if A.split >= batch_dim:  # both splits in la dims
+        A_lshapes_cum = torch.hstack(
+            [
+                torch.zeros(1, dtype=torch.int32, device=tdev),
+                torch.cumsum(A.lshape_map[:, A.split], 0),
+            ]
+        )
+
+        if b.split is None:
+            btilde_loc = b.larray[
+                ..., A_lshapes_cum[comm.rank] : A_lshapes_cum[comm.rank + 1], :
+            ].clone()
+        else:  # b is split at la dim 0
+            btilde_loc = b.larray.clone()
+
+        x = factories.zeros_like(
+            b, device=dev, comm=comm, split=batch_dim
+        )  # split at la dim 0 in case b is not split
+
+        if A.split == batch_dim + 1:
+            for i in range(nprocs - 1, 0, -1):
+                count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
+                displ = A_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
+                count[i:] = 0  # nothing to send, as there are only zero rows
+                displ[i:] = 0
+
+                res_send = torch.empty(0)
+                res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)
+
+                if comm.rank == i:
+                    x.larray = torch.linalg.solve_triangular(
+                        A.larray[..., A_lshapes_cum[i] : A_lshapes_cum[i + 1], :],
+                        btilde_loc,
+                        upper=True,
+                    )
+                    res_send = A.larray @ x.larray
+
+                comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)
+
+                if comm.rank < i:
+                    btilde_loc -= res_recv
+
+            if comm.rank == 0:
+                x.larray = torch.linalg.solve_triangular(
+                    A.larray[..., : A_lshapes_cum[1], :], btilde_loc, upper=True
+                )
+
+        else:  # split dim is la dim 0
+            for i in range(nprocs - 1, 0, -1):
+                idims = tuple(x.lshape_map[i])
+                if comm.rank == i:
+                    x.larray = torch.linalg.solve_triangular(
+                        A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]],
+                        btilde_loc,
+                        upper=True,
+                    )
+                    x_from_i = x.larray
+                else:
+                    x_from_i = torch.zeros(
+                        idims,
+                        dtype=b.dtype.torch_type(),
+                        device=tdev,
+                    )
+
+                comm.Bcast(x_from_i, root=i)
+
+                if comm.rank < i:
+                    btilde_loc -= (
+                        A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i
+                    )
+
+            if comm.rank == 0:
+                x.larray = torch.linalg.solve_triangular(
+                    A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True
+                )
+
+        return x
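
The MPI loops above implement a block backward substitution: the process owning the trailing diagonal block solves it first, the resulting partial solution is multiplied with the corresponding off-diagonal blocks and subtracted (via Scatterv or Bcast) from the right-hand sides held by the earlier processes, which then proceed in the same way. A minimal single-process sketch of the same recurrence for two blocks, in plain PyTorch (block sizes and variable names are illustrative):

import torch

torch.manual_seed(0)
n, n1 = 6, 3  # total size and size of the leading block
A = torch.triu(torch.rand(n, n)) + 2 * torch.eye(n)  # upper triangular, well conditioned
b = torch.rand(n, 1)

# partition A = [[A11, A12], [0, A22]] and b = [[b1], [b2]]
A11, A12, A22 = A[:n1, :n1], A[:n1, n1:], A[n1:, n1:]
b1, b2 = b[:n1].clone(), b[n1:]

x2 = torch.linalg.solve_triangular(A22, b2, upper=True)  # trailing block first
b1 -= A12 @ x2  # update of the earlier right-hand side (the Scatterv/Bcast step above)
x1 = torch.linalg.solve_triangular(A11, b1, upper=True)  # then the leading block

x = torch.cat([x1, x2])
print(torch.allclose(x, torch.linalg.solve_triangular(A, b, upper=True)))  # True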

Diff for: heat/core/linalg/tests/test_solver.py (+131)
@@ -135,3 +135,134 @@ def test_lanczos(self):
         with self.assertRaises(NotImplementedError):
             A = ht.random.randn(10, 10, split=1)
             V, T = ht.lanczos(A, m=3)
+
+    def test_solve_triangular(self):
+        torch.manual_seed(42)
+        tdev = ht.get_device().torch_device
+
+        # non-batched tests
+        k = 100  # data dimension size
+
+        # random triangular matrix inversion
+        at = torch.rand((k, k))
+        # at += torch.eye(k)
+        at += 1e2 * torch.ones_like(at)  # make gaussian elimination more stable
+        at = torch.triu(at).to(tdev)
+
+        ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)
+
+        a = ht.factories.asarray(at, copy=True)
+        c = ht.factories.asarray(ct, copy=True)
+        b = ht.eye(k)
+
+        # exceptions
+        with self.assertRaises(TypeError):  # invalid datatype for b
+            ht.linalg.solve_triangular(a, 42)
+
+        with self.assertRaises(ValueError):  # a no matrix, not enough dimensions
+            ht.linalg.solve_triangular(a[1], b)
+
+        with self.assertRaises(ValueError):  # a and b different number of dimensions
+            ht.linalg.solve_triangular(a, b[1])
+
+        with self.assertRaises(ValueError):  # a no square matrix
+            ht.linalg.solve_triangular(a[1:, ...], b)
+
+        with self.assertRaises(ValueError):  # split=1 for b
+            b.resplit_(-1)
+            ht.linalg.solve_triangular(a, b)
+
+        b.resplit_(0)
+        with self.assertRaises(ValueError):  # dimension mismatch
+            ht.linalg.solve_triangular(a, b[1:, ...])
+
+        for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
+            a.resplit_(s0)
+            b.resplit_(s1)
+
+            res = ht.linalg.solve_triangular(a, b)
+            self.assertTrue(ht.allclose(res, c))
+
+        # triangular ones inversion
+        # for this test case, the results should be exact
+        at = torch.triu(torch.ones_like(at)).to(tdev)
+        ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)
+
+        a = ht.factories.asarray(at, copy=True)
+        c = ht.factories.asarray(ct, copy=True)
+
+        for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
+            a.resplit_(s0)
+            b.resplit_(s1)
+
+            res = ht.linalg.solve_triangular(a, b)
+            self.assertTrue(ht.equal(res, c))
+
+        # batched tests
+        batch_shapes = [
+            (10,),
+            (
+                4,
+                4,
+                4,
+                20,
+            ),
+        ]
+        m = 100  # data dimension size
+
+        # exceptions
+        batch_shape = batch_shapes[1]
+
+        at = torch.rand((*batch_shape, m, m))
+        # at += torch.eye(k)
+        at += 1e2 * torch.ones_like(at)  # make gaussian elimination more stable
+        at = torch.triu(at).to(tdev)
+        bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)
+
+        ct = torch.linalg.solve_triangular(at, bt, upper=True)
+
+        a = ht.factories.asarray(at, copy=True)
+        c = ht.factories.asarray(ct, copy=True)
+        b = ht.factories.asarray(bt, copy=True)
+
+        with self.assertRaises(ValueError):  # batch dimensions of different shapes
+            ht.linalg.solve_triangular(a[1:, ...], b)
+
+        with self.assertRaises(ValueError):  # different batched split dimensions
+            a.resplit_(0)
+            b.resplit_(1)
+            ht.linalg.solve_triangular(a, b)
+
+        for batch_shape in batch_shapes:
+            # batch_shape = tuple()  # no batch dimensions
+
+            at = torch.rand((*batch_shape, m, m))
+            # at += torch.eye(k)
+            at += 1e2 * torch.ones_like(at)  # make gaussian elimination more stable
+            at = torch.triu(at).to(tdev)
+            bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)
+
+            ct = torch.linalg.solve_triangular(at, bt, upper=True)
+
+            a = ht.factories.asarray(at, copy=True)
+            c = ht.factories.asarray(ct, copy=True)
+            b = ht.factories.asarray(bt, copy=True)
+
+            # split in linalg dimension or none
+            for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
+                a.resplit_(s0)
+                b.resplit_(s1)
+
+                res = ht.linalg.solve_triangular(a, b)
+
+                self.assertTrue(ht.allclose(c, res))
+
+            # split in batch dimension
+            s = len(batch_shape) - 1
+            a.resplit_(s)
+            b.resplit_(s)
+            c.resplit_(s)
+
+            res = ht.linalg.solve_triangular(a, b)
+
+            self.assertTrue(ht.allclose(c, res))
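
The batched tests above rely on the usual convention that the leading dimensions index independent systems. The following plain-PyTorch snippet (sizes are illustrative, not taken from the tests) states that convention explicitly by comparing a batched solve against a loop over the batch:

import torch

torch.manual_seed(0)
batch, n, k = 4, 5, 2
A = torch.triu(torch.rand(batch, n, n)) + 2 * torch.eye(n)  # batch of triangular matrices
B = torch.rand(batch, n, k)

X_batched = torch.linalg.solve_triangular(A, B, upper=True)
X_looped = torch.stack(
    [torch.linalg.solve_triangular(A[i], B[i], upper=True) for i in range(batch)]
)
print(torch.allclose(X_batched, X_looped))  # True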
