Skip to content

Commit 59e0a04

Browse files
alexfiklinducer
authored andcommitted
skeletonization: fix crash with low number of proxies
Managed to hit a case where the `src_mat` had shape `(21, 23)` and the `tgt_mat` had shape `(23, 22)`. The rank after skeletonizing the target was `22`, which was smaller then the max one for the source and caused `interp_decomp` to crash somewhere in the Fortran code.
1 parent 488b958 commit 59e0a04

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

pytential/linalg/skeletonization.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,7 @@ def _skeletonize_block_by_proxy_with_mats(
616616
k = id_rank
617617
src_mat = np.vstack(src_result[i])
618618
tgt_mat = np.hstack(tgt_result[i])
619+
max_allowable_rank = min(*src_mat.shape, *tgt_mat.shape)
619620

620621
if __debug__:
621622
isfinite = np.isfinite(tgt_mat)
@@ -625,21 +626,26 @@ def _skeletonize_block_by_proxy_with_mats(
625626

626627
# skeletonize target points
627628
k, idx, interp = interp_decomp(tgt_mat.T, rank=k, eps=id_eps)
628-
assert k > 0
629+
assert 0 < k <= len(idx)
630+
631+
if k > max_allowable_rank:
632+
k = max_allowable_rank
633+
interp = interp[:k, :]
629634

630635
L[i] = interp.T
631636
tgt_skl_indices[i] = tgt_src_index.targets.cluster_indices(i)[idx[:k]]
637+
assert L[i].shape == (tgt_mat.shape[0], k)
632638

633639
# skeletonize source points
634640
k, idx, interp = interp_decomp(src_mat, rank=k, eps=None)
635-
assert k > 0
641+
assert 0 < k <= len(idx)
636642

637643
R[i] = interp
638644
src_skl_indices[i] = tgt_src_index.sources.cluster_indices(i)[idx[:k]]
645+
assert R[i].shape == (k, src_mat.shape[1])
639646

640647
skel_starts[i + 1] = skel_starts[i] + k
641-
assert R[i].shape == (k, src_mat.shape[1])
642-
assert L[i].shape == (tgt_mat.shape[0], k)
648+
assert tgt_skl_indices[i].shape == src_skl_indices[i].shape
643649

644650
from pytential.linalg import make_index_list
645651
src_skl_index = make_index_list(np.hstack(src_skl_indices), skel_starts)

0 commit comments

Comments
 (0)