|
10 | 10 |
|
11 | 11 | import torch
|
12 | 12 |
|
13 | | -__all__ = ["cg", "lanczos"] |
| 13 | +__all__ = ["cg", "lanczos", "solve_triangular"] |
14 | 14 |
|
15 | 15 |
|
16 | 16 | def cg(A: DNDarray, b: DNDarray, x0: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
|
@@ -270,3 +270,194 @@ def lanczos(
|
270 | 270 | V.resplit_(axis=None)
|
271 | 271 |
|
272 | 272 | return V, T
|
| 273 | + |
| 274 | + |
| 275 | +def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: |
| 276 | + """ |
| 277 | + This function provides a solver for (possibly batched) upper triangular systems of linear equations: it returns `x` such that `Ax = b`, where `A` is a (possibly batched) upper triangular matrix and |
| 278 | + `b` a (possibly batched) vector or matrix of suitable shape, both provided as input to the function. |
| 279 | + The implementation builds on the corresponding solver in PyTorch and provides a memory-distributed, MPI-parallel, block-wise version thereof. |
| 280 | + Parameters |
| 281 | + ---------- |
| 282 | + A : DNDarray |
| 283 | + An upper triangular invertible square (n x n) matrix or a batch thereof, i.e. a ``DNDarray`` of shape `(..., n, n)`. |
| 284 | + b : DNDarray |
| 285 | + A (possibly batched) n x k matrix, i.e. a ``DNDarray`` of shape `(..., n, k)`, where the batch dimensions denoted by ... must coincide with those of A. |
| 286 | + (Batched) vectors have to be provided as `... x n x 1` matrices, and the split dimension of b must not be the last dimension (if b is split at all). |
| 287 | + Notes |
| 288 | + ----- |
| 289 | + Since such a check might be computationally expensive, we do not check whether A is indeed upper triangular. |
| 290 | + If you require such a check, please open an issue on our GitHub page and request this feature. |
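| | + Examples |
| | + -------- |
| | + A minimal usage sketch (illustrative only; assumes the usual ``import heat as ht`` namespace): |
| | + >>> A = ht.triu(ht.random.rand(4, 4, split=0)) + 4 * ht.eye(4, split=0) |
| | + >>> b = ht.random.rand(4, 2, split=0) |
| | + >>> x = ht.linalg.solve_triangular(A, b) |
| | + >>> ht.allclose(A @ x, b) |
| | + True |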
| 291 | + """ |
| 292 | + if not isinstance(A, DNDarray) or not isinstance(b, DNDarray): |
| 293 | + raise TypeError(f"Arguments need to be of type DNDarray, got {type(A)}, {type(b)}.") |
| 294 | + if not A.ndim >= 2: |
| 295 | + raise ValueError("A needs to be a (batched) matrix.") |
| 296 | + if not b.ndim == A.ndim: |
| 297 | + raise ValueError("b needs to have the same number of (batch) dimensions as A.") |
| 298 | + if not A.shape[-2] == A.shape[-1]: |
| 299 | + raise ValueError("A needs to be a (batched) square matrix.") |
| 300 | + |
| 301 | + batch_dim = A.ndim - 2 |
| 302 | + batch_shape = A.shape[:batch_dim] |
| 303 | + |
| 304 | + if not A.shape[:batch_dim] == b.shape[:batch_dim]: |
| 305 | + raise ValueError("Batch dimensions of A and b must be of the same shape.") |
| 306 | + if b.split == batch_dim + 1: |
| 307 | + raise ValueError("The last dimension of b is not allowed as split dimension for the right-hand side.") |
| 308 | + if not b.shape[batch_dim] == A.shape[-1]: |
| 309 | + raise ValueError("Dimension mismatch of A and b.") |
| 310 | + |
| 311 | + if ( |
| 312 | + (A.split is not None and A.split < batch_dim) or (b.split is not None and b.split < batch_dim) |
| 313 | + ): # batch split |
| 314 | + if A.split != b.split: |
| 315 | + raise ValueError( |
| 316 | + "If a split dimension is a batch dimension, A and b must have the same split dimension. A possible solution would be a resplit of A or b to the same split dimension." |
| 317 | + ) |
| 318 | + else: |
| 319 | + if ( |
| 320 | + A.split is not None and b.split is not None |
| 321 | + ): # both A and b split along a linear-algebra dimension --> b.split == batch_dim |
| 322 | + # TODO remove? |
| 323 | + if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]): |
| 324 | + raise RuntimeError( |
| 325 | + "The process-local arrays of A and b have different sizes along the splitted axis. This is most likely due to one of the DNDarrays being in unbalanced state. \n Consider using `A.is_balanced(force_check=True)` and `b.is_balanced(force_check=True)` to check if A and b are balanced; \n then call `A.balance_()` and/or `b.balance_()` in order to achieve equal local shapes along the split axis before applying `solve_triangular`." |
| 326 | + ) |
| 327 | + |
| 328 | + comm = A.comm |
| 329 | + dev = A.device |
| 330 | + tdev = dev.torch_device |
| 331 | + |
| 332 | + nprocs = comm.Get_size() |
| 333 | + |
| 334 | + if A.split is None: # A not split |
| 335 | + if b.split is None: |
| 336 | + x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True) |
| 337 | + |
| 338 | + return factories.array(x, dtype=b.dtype, device=dev, comm=comm) |
| 339 | + else: # A not split, b.split == -2 |
| 340 | + b_lshapes_cum = torch.hstack( |
| 341 | + [ |
| 342 | + torch.zeros(1, dtype=torch.int32, device=tdev), |
| 343 | + torch.cumsum(b.lshape_map[:, -2], 0), |
| 344 | + ] |
| 345 | + ) |
| 346 | + |
| 347 | + btilde_loc = b.larray.clone() |
| 348 | + A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]] |
| 349 | + |
| 350 | + x = factories.zeros_like(b, device=dev, comm=comm) |
| 351 | + |
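| | + # Block backward substitution over the row blocks of b, from the last process to the first: |
| | + # process i solves its diagonal block for its slice of x, the partial products A[:, block_i] @ x_i |
| | + # are scattered to the lower-ranked processes, which subtract them from their local right-hand sides. |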
| 352 | + for i in range(nprocs - 1, 0, -1): |
| 353 | + count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy() |
| 354 | + displ = b_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy() |
| 355 | + count[i:] = 0 # nothing to send, as there are only zero rows |
| 356 | + displ[i:] = 0 |
| 357 | + |
| 358 | + res_send = torch.empty(0) |
| 359 | + res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev) |
| 360 | + |
| 361 | + if comm.rank == i: |
| 362 | + x.larray = torch.linalg.solve_triangular( |
| 363 | + A_loc[..., b_lshapes_cum[i] : b_lshapes_cum[i + 1], :], |
| 364 | + btilde_loc, |
| 365 | + upper=True, |
| 366 | + ) |
| 367 | + res_send = A_loc @ x.larray |
| 368 | + |
| 369 | + comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim) |
| 370 | + |
| 371 | + if comm.rank < i: |
| 372 | + btilde_loc -= res_recv |
| 373 | + |
| 374 | + if comm.rank == 0: |
| 375 | + x.larray = torch.linalg.solve_triangular( |
| 376 | + A_loc[..., : b_lshapes_cum[1], :], btilde_loc, upper=True |
| 377 | + ) |
| 378 | + |
| 379 | + return x |
| 380 | + |
| 381 | + if A.split < batch_dim: # batch split |
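| | + # Purely batch-wise split: every process holds complete (n x n) systems for its share of the batch, |
| | + # so the local PyTorch solver can be applied directly. |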
| 382 | + x = factories.zeros_like(b, device=dev, comm=comm, split=A.split) |
| 383 | + x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True) |
| 384 | + |
| 385 | + return x |
| 386 | + |
| 387 | + if A.split >= batch_dim: # A split along a linear-algebra dimension |
| 388 | + A_lshapes_cum = torch.hstack( |
| 389 | + [ |
| 390 | + torch.zeros(1, dtype=torch.int32, device=tdev), |
| 391 | + torch.cumsum(A.lshape_map[:, A.split], 0), |
| 392 | + ] |
| 393 | + ) |
| 394 | + |
| 395 | + if b.split is None: |
| 396 | + btilde_loc = b.larray[ |
| 397 | + ..., A_lshapes_cum[comm.rank] : A_lshapes_cum[comm.rank + 1], : |
| 398 | + ].clone() |
| 399 | + else: # b is split at la dim 0 |
| 400 | + btilde_loc = b.larray.clone() |
| 401 | + |
| 402 | + x = factories.zeros_like( |
| 403 | + b, device=dev, comm=comm, split=batch_dim |
| 404 | + ) # split at la dim 0 in case b is not split |
| 405 | + |
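| | + # A is split along a linear-algebra dimension: block backward substitution from the last process |
| | + # block to the first. If A is split along its columns, partial products are scattered to the |
| | + # lower-ranked processes (branch below); if A is split along its rows, the solved slice of x is |
| | + # broadcast instead (else branch). |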
| 406 | + if A.split == batch_dim + 1: |
| 407 | + for i in range(nprocs - 1, 0, -1): |
| 408 | + count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy() |
| 409 | + displ = A_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy() |
| 410 | + count[i:] = 0 # nothing to send, as there are only zero rows |
| 411 | + displ[i:] = 0 |
| 412 | + |
| 413 | + res_send = torch.empty(0) |
| 414 | + res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev) |
| 415 | + |
| 416 | + if comm.rank == i: |
| 417 | + x.larray = torch.linalg.solve_triangular( |
| 418 | + A.larray[..., A_lshapes_cum[i] : A_lshapes_cum[i + 1], :], |
| 419 | + btilde_loc, |
| 420 | + upper=True, |
| 421 | + ) |
| 422 | + res_send = A.larray @ x.larray |
| 423 | + |
| 424 | + comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim) |
| 425 | + |
| 426 | + if comm.rank < i: |
| 427 | + btilde_loc -= res_recv |
| 428 | + |
| 429 | + if comm.rank == 0: |
| 430 | + x.larray = torch.linalg.solve_triangular( |
| 431 | + A.larray[..., : A_lshapes_cum[1], :], btilde_loc, upper=True |
| 432 | + ) |
| 433 | + |
| 434 | + else: # split dim is la dim 0 |
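| | + # A is split along its rows: process i solves its diagonal block, broadcasts its slice of x, |
| | + # and lower-ranked processes update their local right-hand sides with the matching column block |
| | + # of their local A. |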
| 435 | + for i in range(nprocs - 1, 0, -1): |
| 436 | + idims = tuple(x.lshape_map[i]) |
| 437 | + if comm.rank == i: |
| 438 | + x.larray = torch.linalg.solve_triangular( |
| 439 | + A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]], |
| 440 | + btilde_loc, |
| 441 | + upper=True, |
| 442 | + ) |
| 443 | + x_from_i = x.larray |
| 444 | + else: |
| 445 | + x_from_i = torch.zeros( |
| 446 | + idims, |
| 447 | + dtype=b.dtype.torch_type(), |
| 448 | + device=tdev, |
| 449 | + ) |
| 450 | + |
| 451 | + comm.Bcast(x_from_i, root=i) |
| 452 | + |
| 453 | + if comm.rank < i: |
| 454 | + btilde_loc -= ( |
| 455 | + A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i |
| 456 | + ) |
| 457 | + |
| 458 | + if comm.rank == 0: |
| 459 | + x.larray = torch.linalg.solve_triangular( |
| 460 | + A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True |
| 461 | + ) |
| 462 | + |
| 463 | + return x |