-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathlinalg.py
95 lines (68 loc) · 1.76 KB
/
linalg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# flake8: noqa
import heat as ht
from mpi4py import MPI
from perun import monitor
@monitor()
def matmul_split_0(a, b):
    """Benchmark the matrix product ``a @ b`` (split=0 variant).

    The perun ``@monitor()`` decorator instruments each call. The operands
    are created by the caller, so only the multiplication falls inside the
    monitored call. The function is meant for operands created with
    ``split=0``, but it does not check this. The product is discarded.
    """
    _ = a @ b
@monitor()
def matmul_split_1(a, b):
    """Benchmark the matrix product ``a @ b`` (split=1 variant).

    The perun ``@monitor()`` decorator instruments each call. The operands
    are created by the caller, so only the multiplication falls inside the
    monitored call. The function is meant for operands created with
    ``split=1``, but it does not check this. The product is discarded.
    """
    _ = a @ b
@monitor()
def qr_split_0(a):
    """Benchmark ``ht.linalg.qr`` on ``a`` (split=0 variant).

    The function is meant for an input created with ``split=0``, but it
    does not check this. The factorization result is discarded; only the
    monitored call matters.
    """
    _ = ht.linalg.qr(a)
@monitor()
def qr_split_1(a):
    """Benchmark ``ht.linalg.qr`` on ``a`` (split=1 variant).

    The function is meant for an input created with ``split=1``, but it
    does not check this. The factorization result is discarded; only the
    monitored call matters.
    """
    _ = ht.linalg.qr(a)
@monitor()
def hierachical_svd_rank(data, r):
    """Benchmark ``ht.linalg.hsvd_rank`` on ``data`` with ``maxrank=r``.

    Singular values are requested (``compute_sv=True``) and the library
    output is silenced (``silent=True``). The result is discarded.

    NOTE(review): the misspelling "hierachical" is kept because this is the
    public name; presumably perun also reports the region under it.
    """
    _ = ht.linalg.hsvd_rank(
        data,
        maxrank=r,
        compute_sv=True,
        silent=True,
    )
@monitor()
def hierachical_svd_tol(data, tol):
    """Benchmark ``ht.linalg.hsvd_rtol`` on ``data`` with ``rtol=tol``.

    Singular values are requested (``compute_sv=True``) and the library
    output is silenced (``silent=True``). The result is discarded.

    NOTE(review): the misspelling "hierachical" is kept because this is the
    public name; presumably perun also reports the region under it.
    """
    _ = ht.linalg.hsvd_rtol(
        data,
        rtol=tol,
        compute_sv=True,
        silent=True,
    )
@monitor()
def lanczos(B):
    """Benchmark ``ht.lanczos`` on ``B`` with ``m = B.shape[0]``.

    The iteration count is set to the size of ``B``'s first dimension.
    The call's result is unpacked into two values and then discarded.
    """
    iterations = B.shape[0]
    _v, _t = ht.lanczos(B, m=iterations)
@monitor()
def zolopd_split0(A):
    """Benchmark the polar decomposition ``ht.linalg.polar(A)`` (split=0 variant).

    The function is meant for an input created with ``split=0``, but it
    does not check this. The name suggests the Zolo-PD algorithm
    (presumably what ``ht.linalg.polar`` uses; confirm against the Heat
    docs). The two returned factors are discarded.
    """
    _u, _h = ht.linalg.polar(A)
@monitor()
def zolopd_split1(A):
    """Benchmark the polar decomposition ``ht.linalg.polar(A)`` (split=1 variant).

    The function is meant for an input created with ``split=1``, but it
    does not check this. The name suggests the Zolo-PD algorithm
    (presumably what ``ht.linalg.polar`` uses; confirm against the Heat
    docs). The two returned factors are discarded.
    """
    _u, _h = ht.linalg.polar(A)
def _bench_matmul(n=3000):
    """Run the matmul benchmarks on random n x n operands, split=0 then split=1."""
    for split, bench in ((0, matmul_split_0), (1, matmul_split_1)):
        a = ht.random.random((n, n), split=split)
        b = ht.random.random((n, n), split=split)
        bench(a, b)
        # Free both operands before the next pair is allocated to keep peak memory down.
        del a, b


def _bench_qr(comm_size, square_n=2000):
    """Run the QR benchmarks: tall-skinny input for split=0, square input for split=1."""
    # The split=0 input has shape (comm_size * n, n), so the global matrix
    # holds about 4,000,000 elements whatever the process count. With an
    # even row distribution, each process holds roughly an n x n block.
    n = int((4000000 // comm_size) ** 0.5)
    a_0 = ht.random.random((comm_size * n, n), split=0)
    qr_split_0(a_0)
    del a_0

    a_1 = ht.random.random((square_n, square_n), split=1)
    qr_split_1(a_1)


def _bench_lanczos(n=50):
    """Run the Lanczos benchmark on the n x n float64 matrix A @ A.T."""
    a = ht.random.random((n, n), dtype=ht.float64, split=0)
    # A @ A.T is symmetric, the setting in which the Lanczos iteration is used.
    b = a @ a.T
    lanczos(b)


def _bench_hsvd(comm_size):
    """Run both hierarchical-SVD benchmarks on one synthetic split=1 matrix."""
    # The positional arguments are presumably (rows, cols, rank). Their rank
    # of 10 matches the maxrank passed below; confirm against the Heat docs.
    # The column count grows with the number of processes.
    # [0] selects the generated matrix that both benchmarks consume.
    data = ht.utils.data.matrixgallery.random_known_rank(
        1000, 500 * comm_size, 10, split=1, dtype=ht.float32
    )[0]
    hierachical_svd_rank(data, 10)
    hierachical_svd_tol(data, 1e-2)


def _bench_polar(n=1000):
    """Run the polar-decomposition benchmarks on random n x n inputs, split=0 then split=1."""
    for split, bench in ((0, zolopd_split0), (1, zolopd_split1)):
        a = ht.random.random((n, n), split=split)
        bench(a)
        # Free the input before the next one is allocated.
        del a


def run_linalg_benchmarks():
    """Run every linear-algebra benchmark in this module once, in a fixed order.

    Each benchmark's input is built outside the monitored function, so
    creating it is not part of the monitored call. Each input is released
    before the next one is allocated. The order (and so the sequence of
    random draws) is matmul, QR, Lanczos, hierarchical SVD, then polar
    decomposition.
    """
    # Query the communicator size once. The original mixed `.size` and
    # `.Get_size()`, which return the same value.
    comm_size = MPI.COMM_WORLD.Get_size()
    _bench_matmul()
    _bench_qr(comm_size)
    _bench_lanczos()
    _bench_hsvd(comm_size)
    _bench_polar()