Skip to content

Commit d70f155

Browse files
committed
update rambo test
1 parent 69d06bd commit d70f155

File tree

1 file changed

+193
-58
lines changed

1 file changed

+193
-58
lines changed

examples/rambo.py

Lines changed: 193 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,222 @@
1-
import sharpy as sp
2-
import numpy as np
1+
"""
32
4-
def sp_rambo(sp_C1, sp_F1, sp_Q1, sp_output, C1, F1, Q1, output):
3+
Examples:
4+
python rambo.py -nevts 10 -nout 10 -b sharpy -i 10000
55
6-
sp_C = 2.0 * sp_C1 - 1.0
7-
sp_S = sp.sqrt(1 - sp.square(sp_C))
8-
sp_F = 2.0 * sp.pi * sp_F1
9-
sp_Q = -sp.log(sp_Q1)
6+
"""
7+
import argparse
8+
import os
9+
import time as time_mod
10+
from functools import partial
1011

11-
sp_output[:, :, 0] = sp.reshape(sp_Q, (10, 10, 1))
12-
sp_output[:, :, 1] = sp.reshape(sp_Q * sp_S * sp.sin(sp_F), (10, 10, 1))
13-
sp_output[:, :, 2] = sp.reshape(sp_Q * sp_S * sp.cos(sp_F), (10, 10, 1))
14-
sp_output[:, :, 3] = sp.reshape(sp_Q * sp_C, (10, 10, 1))
12+
import numpy
13+
import sharpy
1514

16-
C = 2.0 * C1 - 1.0
17-
S = np.sqrt(1 - np.square(C))
18-
F = 2.0 * np.pi * F1
19-
Q = -np.log(Q1)
15+
try:
16+
import mpi4py
2017

21-
output[:, :, 0] = Q
22-
output[:, :, 1] = Q * S * np.sin(F)
23-
output[:, :, 2] = Q * S * np.cos(F)
24-
output[:, :, 3] = Q * C
25-
26-
sp.sync()
18+
mpi4py.rc.finalize = False
19+
from mpi4py import MPI
20+
21+
comm_rank = MPI.COMM_WORLD.Get_rank()
22+
comm = MPI.COMM_WORLD
23+
except ImportError:
24+
comm_rank = 0
25+
comm = None
2726

27+
def info(s):
28+
if comm_rank == 0:
29+
print(s)
2830

29-
def sp_initialize(nevts, nout, seed, types_dict):
30-
dtype = types_dict["float"]
31+
def naive_erf(x):
32+
"""
33+
Error function (erf) implementation
3134
32-
sp.random.seed(seed)
33-
C1 = sp.random.rand(nevts, nout)
34-
F1 = sp.random.rand(nevts, nout)
35-
Q1 = sp.random.rand(nevts, nout) * sp.random.rand(nevts, nout)
35+
Adapted from formula 7.1.26 in
36+
Abramowitz and Stegun, "Handbook of Mathematical Functions", 1965.
37+
"""
38+
y = numpy.abs(x)
3639

37-
sp.sync()
40+
a1 = 0.254829592
41+
a2 = -0.284496736
42+
a3 = 1.421413741
43+
a4 = -1.453152027
44+
a5 = 1.061405429
45+
p = 0.3275911
3846

39-
return (C1, F1, Q1, sp.zeros((nevts, nout, 4), dtype))
47+
t = 1.0 / (1.0 + p * y)
48+
f = (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t
49+
return numpy.sign(x) * (1.0 - f * numpy.exp(-y * y))
4050

41-
def np_rambo(nevts, nout, C1, F1, Q1, output):
51+
def sp_rambo(sp, sp_C1, sp_F1, sp_Q1, sp_output, nevts, nout):
52+
sp_C = 2.0 * sp_C1 - 1.0
53+
sp_S = sp.sqrt(1 - sp.square(sp_C))
54+
sp_F = 2.0 * sp.pi * sp_F1
55+
sp_Q = -sp.log(sp_Q1)
56+
57+
sp_output[:, :, 0] = sp.reshape(sp_Q, (nevts, nout, 1))
58+
sp_output[:, :, 1] = sp.reshape(sp_Q * sp_S * sp.sin(sp_F), (nevts, nout, 1))
59+
sp_output[:, :, 2] = sp.reshape(sp_Q * sp_S * sp.cos(sp_F), (nevts, nout, 1))
60+
sp_output[:, :, 3] = sp.reshape(sp_Q * sp_C, (nevts, nout, 1))
61+
62+
sharpy.sync()
63+
64+
def np_rambo(np, C1, F1, Q1, output, nevts, nout):
4265
C = 2.0 * C1 - 1.0
4366
S = np.sqrt(1 - np.square(C))
4467
F = 2.0 * np.pi * F1
4568
Q = -np.log(Q1)
4669

47-
# numpy: could not broadcast input array from shape (10,10,1) into shape (10,10)
4870
output[:, :, 0] = Q
4971
output[:, :, 1] = Q * S * np.sin(F)
5072
output[:, :, 2] = Q * S * np.cos(F)
5173
output[:, :, 3] = Q * C
5274

53-
def np_initialize(nevts, nout, seed, types_dict):
54-
dtype = types_dict["float"]
55-
75+
def initialize(np, nevts, nout, seed, dtype):
5676
np.random.seed(seed)
5777
C1 = np.random.rand(nevts, nout)
5878
F1 = np.random.rand(nevts, nout)
5979
Q1 = np.random.rand(nevts, nout) * np.random.rand(nevts, nout)
60-
6180
return (C1, F1, Q1, np.zeros((nevts, nout, 4), dtype))
62-
63-
def np_run(nevts, nout, seed=42):
64-
types_dict = {
65-
"float": sp.float64
66-
}
67-
sp_C1, sp_F1, sp_Q1, sp_output = sp_initialize(nevts, nout, seed, types_dict)
68-
# sp_rambo(nevts, nout, sp_C1, sp_F1, sp_Q1, sp_output)
69-
70-
types_dict = {
71-
"float": np.float64
72-
}
73-
np_C1, np_F1, np_Q1, np_output = np_initialize(nevts, nout, seed, types_dict)
74-
75-
# assert np.allclose(sp.to_numpy(sp_C1), np_C1)
76-
# assert np.allclose(sp.to_numpy(sp_F1), np_F1)
77-
# assert np.allclose(sp.to_numpy(sp_Q1), np_Q1)
78-
79-
sp_rambo(sp_C1, sp_F1, sp_Q1, sp_output, np_C1, np_F1, np_Q1, np_output)
8081

81-
assert np.allclose(sp.to_numpy(sp_output), np_output)
82+
def run(nevts, nout, backend, iterations, datatype):
83+
if backend == "sharpy":
84+
import sharpy as np
85+
from sharpy import fini, init, sync
86+
87+
device = os.getenv("SHARPY_DEVICE", "")
88+
create_full = partial(np.full, device=device)
89+
random_rand = partial(np.random.rand, device=device)
90+
erf = np.erf
91+
rambo = sp_rambo
92+
93+
init(False)
94+
elif backend == "numpy":
95+
import numpy as np
96+
97+
if comm is not None:
98+
assert (
99+
comm.Get_size() == 1
100+
), "Numpy backend only supports serial execution."
101+
102+
create_full = np.full
103+
random_rand = np.random.rand
104+
fini = sync = lambda x=None: None
105+
erf = naive_erf
106+
rambo = np_rambo
107+
else:
108+
raise ValueError(f'Unknown backend: "{backend}"')
109+
110+
dtype = {
111+
"f32": np.float32,
112+
"f64": np.float64,
113+
}[datatype]
114+
115+
info(f"Using backend: {backend}")
116+
info(f"Number of events: {nevts}")
117+
info(f"Number of outputs: {nout}")
118+
info(f"Datatype: {datatype}")
119+
120+
seed = 7777
121+
C1, F1, Q1, output = initialize(np, nevts, nout, seed, dtype)
122+
sync()
123+
124+
# verify
125+
if backend == "sharpy":
126+
sp_rambo(sharpy, C1, F1, Q1, output, nevts, nout)
127+
# sync() !! not work here?
128+
np_C1 = sharpy.to_numpy(C1)
129+
np_F1 = sharpy.to_numpy(F1)
130+
np_Q1 = sharpy.to_numpy(Q1)
131+
np_output = numpy.zeros((nevts, nout, 4))
132+
np_rambo(numpy, np_C1, np_F1, np_Q1, np_output, nevts, nout)
133+
assert numpy.allclose(sharpy.to_numpy(output), np_output)
134+
135+
def eval():
136+
tic = time_mod.perf_counter()
137+
rambo(np, C1, F1, Q1, output, nevts, nout)
138+
toc = time_mod.perf_counter()
139+
return toc - tic
140+
141+
# warm-up run
142+
t_warm = eval()
143+
144+
# evaluate
145+
info(f"Running {iterations} iterations")
146+
time_list = []
147+
for i in range(iterations):
148+
time_list.append(eval())
149+
150+
# get max time over mpi ranks
151+
if comm is not None:
152+
t_warm = comm.allreduce(t_warm, MPI.MAX)
153+
time_list = comm.allreduce(time_list, MPI.MAX)
154+
155+
t_min = numpy.min(time_list)
156+
t_max = numpy.max(time_list)
157+
t_med = numpy.median(time_list)
158+
# perf_rate = nopt / t_med / 1e6 # million options per second
159+
init_overhead = t_warm - t_med
160+
if backend == "sharpy":
161+
info(f"Estimated initialization overhead: {init_overhead:.5f} s")
162+
info(f"Min. duration: {t_min:.5f} s")
163+
info(f"Max. duration: {t_max:.5f} s")
164+
info(f"Median duration: {t_med:.5f} s")
165+
# info(f"Median rate: {perf_rate:.5f} Mopts/s")
166+
167+
fini()
82168

83169
if __name__ == "__main__":
84-
sp.init(False)
85-
nevts, nout = 10, 10
86-
np_run(nevts, nout)
87-
sp.fini()
170+
parser = argparse.ArgumentParser(
171+
description="Run rambo benchmark",
172+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
173+
)
174+
175+
parser.add_argument(
176+
"-nevts",
177+
"--num_events",
178+
type=int,
179+
default=10,
180+
help="Number of events to evaluate.",
181+
)
182+
parser.add_argument(
183+
"-nout",
184+
"--num_outputs",
185+
type=int,
186+
default=10,
187+
help="Number of outputs to evaluate.",
188+
)
189+
190+
parser.add_argument(
191+
"-b",
192+
"--backend",
193+
type=str,
194+
default="sharpy",
195+
choices=["sharpy", "numpy"],
196+
help="Backend to use.",
197+
)
198+
199+
parser.add_argument(
200+
"-i",
201+
"--iterations",
202+
type=int,
203+
default=10,
204+
help="Number of iterations to run.",
205+
)
206+
parser.add_argument(
207+
"-d",
208+
"--datatype",
209+
type=str,
210+
default="f64",
211+
choices=["f32", "f64"],
212+
help="Datatype for model state variables",
213+
)
214+
args = parser.parse_args()
215+
nevts, nout = args.num_events, args.num_outputs
216+
run(
217+
nevts,
218+
nout,
219+
args.backend,
220+
args.iterations,
221+
args.datatype,
222+
)

0 commit comments

Comments
 (0)