Skip to content

Commit c4aa251

Browse files
committed
Refactor random module and add tests
1 parent 6456bfb commit c4aa251

File tree

8 files changed

+379
-21
lines changed

8 files changed

+379
-21
lines changed

examples/rambo.py

+227
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
"""
2+
3+
Examples:
4+
python rambo.py -nevts 10 -nout 10 -b sharpy -i 10000
5+
mpiexec -n 3 python rambo.py -nevts 64 -nout 64 -b sharpy -i 100
6+
7+
"""
8+
9+
import argparse
10+
import time as time_mod
11+
12+
import numpy
13+
14+
import sharpy
15+
16+
# Optional MPI support: fall back to a single-process setup when mpi4py is
# not installed.
try:
    import mpi4py

    # Has to be configured before ``MPI`` is imported to take effect.
    mpi4py.rc.finalize = False
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    comm_rank = comm.Get_rank()
except ImportError:
    comm, comm_rank = None, 0
27+
28+
29+
def info(s):
    """Print ``s``, but only on the process whose ``comm_rank`` is 0."""
    if comm_rank != 0:
        return
    print(s)
32+
33+
34+
def naive_erf(x):
    """
    Polynomial approximation of the error function (erf).

    Adapted from formula 7.1.26 in
    Abramowitz and Stegun, "Handbook of Mathematical Functions", 1965.

    The approximation is evaluated on ``|x|`` and the sign of ``x`` is
    restored at the end.
    """
    # Coefficients a5 .. a1, ordered for Horner evaluation.
    horner_coeffs = (
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    )
    p = 0.3275911

    magnitude = numpy.abs(x)
    t = 1.0 / (1.0 + p * magnitude)

    # ((((a5*t + a4)*t + a3)*t + a2)*t + a1)*t
    poly = horner_coeffs[0] * t
    for coeff in horner_coeffs[1:]:
        poly = (poly + coeff) * t

    return numpy.sign(x) * (1.0 - poly * numpy.exp(-magnitude * magnitude))
53+
54+
55+
def sp_rambo(sp, sp_C1, sp_F1, sp_Q1, sp_output, nevts, nout):
    """
    Fill ``sp_output[:, :, 0..3]`` from the three input arrays using the
    array namespace ``sp``.

    Every component is reshaped to ``(nevts, nout, 1)`` before it is assigned
    to its slice of ``sp_output``. The function ends with ``sharpy.sync()``.
    """
    slab_shape = (nevts, nout, 1)

    c = 2.0 * sp_C1 - 1.0
    s = sp.sqrt(1 - sp.square(c))
    angle = 2.0 * sp.pi * sp_F1
    q = -sp.log(sp_Q1)

    # Statement order is kept as-is: the backend may defer execution.
    sp_output[:, :, 0] = sp.reshape(q, slab_shape)
    sp_output[:, :, 1] = sp.reshape(q * s * sp.sin(angle), slab_shape)
    sp_output[:, :, 2] = sp.reshape(q * s * sp.cos(angle), slab_shape)
    sp_output[:, :, 3] = sp.reshape(q * c, slab_shape)

    sharpy.sync()
71+
72+
73+
def np_rambo(np, C1, F1, Q1, output, nevts, nout):
    """
    Fill ``output[:, :, 0..3]`` in place from ``C1``, ``F1`` and ``Q1`` using
    the array namespace ``np``.

    ``nevts`` and ``nout`` are accepted for signature parity with
    ``sp_rambo`` but are not used here. Returns ``None``.
    """
    c = 2.0 * C1 - 1.0
    s = np.sqrt(1 - np.square(c))
    angle = 2.0 * np.pi * F1
    q = -np.log(Q1)

    components = (
        q,
        q * s * np.sin(angle),
        q * s * np.cos(angle),
        q * c,
    )
    for slot, values in enumerate(components):
        output[:, :, slot] = values
83+
84+
85+
def initialize(np, nevts, nout, seed, dtype):
    """
    Seed ``np.random`` and create the benchmark inputs.

    Returns a tuple ``(C1, F1, Q1, output)`` where ``C1`` and ``F1`` are
    ``np.random.rand(nevts, nout)`` draws, ``Q1`` is the element-wise product
    of two further draws, and ``output`` is a zero array of shape
    ``(nevts, nout, 4)`` with the requested ``dtype``.
    """
    np.random.seed(seed)
    plane = (nevts, nout)

    # The draw order matters for reproducibility: C1, F1, then both Q1 factors.
    C1 = np.random.rand(*plane)
    F1 = np.random.rand(*plane)
    q_first = np.random.rand(*plane)
    q_second = np.random.rand(*plane)
    Q1 = q_first * q_second

    output = np.zeros((nevts, nout, 4), dtype)
    return (C1, F1, Q1, output)
91+
92+
93+
def run(nevts, nout, backend, iterations, datatype):
    """
    Run and time the rambo kernel.

    Args:
        nevts: first dimension of the input/output arrays.
        nout: second dimension of the input/output arrays.
        backend: ``"sharpy"`` or ``"numpy"``.
        iterations: number of timed kernel invocations (after one warm-up).
        datatype: ``"f32"`` or ``"f64"``; dtype of the output array.

    Raises:
        ValueError: if ``backend`` is not one of the supported names.
        KeyError: if ``datatype`` is not ``"f32"`` or ``"f64"``.
        AssertionError: if the numpy backend is started on more than one MPI
            rank, or if the sharpy result does not match the numpy reference.
    """
    if backend == "sharpy":
        import sharpy as np
        from sharpy import fini, init, sync

        rambo = sp_rambo

        init(False)
    elif backend == "numpy":
        import numpy as np

        if comm is not None:
            assert (
                comm.Get_size() == 1
            ), "Numpy backend only supports serial execution."

        # numpy needs no runtime setup/teardown: use no-op stand-ins.
        fini = sync = lambda x=None: None
        rambo = np_rambo
    else:
        raise ValueError(f'Unknown backend: "{backend}"')

    dtype = {
        "f32": np.float32,
        "f64": np.float64,
    }[datatype]

    info(f"Using backend: {backend}")
    info(f"Number of events: {nevts}")
    info(f"Number of outputs: {nout}")
    info(f"Datatype: {datatype}")

    seed = 7777
    C1, F1, Q1, output = initialize(np, nevts, nout, seed, dtype)
    sync()

    # verify: compare the sharpy result against a plain numpy reference
    if backend == "sharpy":
        sp_rambo(sharpy, C1, F1, Q1, output, nevts, nout)
        # NOTE: the original author reported that calling sync() here did not
        # work (open question); sp_rambo already ends with sharpy.sync().
        np_C1 = sharpy.to_numpy(C1)
        np_F1 = sharpy.to_numpy(F1)
        np_Q1 = sharpy.to_numpy(Q1)
        np_output = numpy.zeros((nevts, nout, 4))
        np_rambo(numpy, np_C1, np_F1, np_Q1, np_output, nevts, nout)
        assert numpy.allclose(sharpy.to_numpy(output), np_output)

    def timed_call():
        """Invoke the kernel once and return its wall-clock duration in seconds."""
        tic = time_mod.perf_counter()
        rambo(np, C1, F1, Q1, output, nevts, nout)
        toc = time_mod.perf_counter()
        return toc - tic

    # warm-up run
    t_warm = timed_call()

    # evaluate
    info(f"Running {iterations} iterations")
    time_list = [timed_call() for _ in range(iterations)]

    # get max time over mpi ranks
    if comm is not None:
        t_warm = comm.allreduce(t_warm, MPI.MAX)
        # The object-based allreduce would compare whole Python lists
        # lexicographically and return a single rank's list. Reduce a float
        # buffer instead so that the maximum is taken per iteration.
        local_times = numpy.asarray(time_list, dtype=numpy.float64)
        time_list = numpy.empty_like(local_times)
        comm.Allreduce(local_times, time_list, op=MPI.MAX)

    t_min = numpy.min(time_list)
    t_max = numpy.max(time_list)
    t_med = numpy.median(time_list)
    # perf_rate = nopt / t_med / 1e6  # million options per second
    init_overhead = t_warm - t_med
    if backend == "sharpy":
        info(f"Estimated initialization overhead: {init_overhead:.5f} s")
    info(f"Min. duration: {t_min:.5f} s")
    info(f"Max. duration: {t_max:.5f} s")
    info(f"Median duration: {t_med:.5f} s")
    # info(f"Median rate: {perf_rate:.5f} Mopts/s")

    fini()
172+
173+
174+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Run rambo benchmark",
    )

    # (flags, keyword options), registered in the order shown by --help.
    cli_options = (
        (
            ("-nevts", "--num_events"),
            dict(type=int, default=10, help="Number of events to evaluate."),
        ),
        (
            ("-nout", "--num_outputs"),
            dict(type=int, default=10, help="Number of outputs to evaluate."),
        ),
        (
            ("-b", "--backend"),
            dict(
                type=str,
                default="sharpy",
                choices=["sharpy", "numpy"],
                help="Backend to use.",
            ),
        ),
        (
            ("-i", "--iterations"),
            dict(type=int, default=10, help="Number of iterations to run."),
        ),
        (
            ("-d", "--datatype"),
            dict(
                type=str,
                default="f64",
                choices=["f32", "f64"],
                help="Datatype for model state variables",
            ),
        ),
    )
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()
    nevts = args.num_events
    nout = args.num_outputs
    run(nevts, nout, args.backend, args.iterations, args.datatype)

sharpy/__init__.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,22 @@
3838
from ._sharpy import sync
3939
from .ndarray import ndarray
4040

41+
42+
# Lazy load submodules
43+
def __getattr__(name):
    """
    Module-level attribute hook, called only when normal lookup fails.

    ``random`` and ``numpy`` are imported as submodules on first access.
    Any other name is forwarded to ``_fallback`` once that helper has been
    defined in this module.

    Raises:
        AttributeError: if ``name`` is not a lazily loaded submodule and
            ``_fallback`` is not (yet) available. Without this the hook would
            implicitly return ``None`` for every unknown attribute.
    """
    if name == "random":
        import sharpy.random as random

        return random
    if name == "numpy":
        import sharpy.numpy as numpy

        return numpy

    # _fallback is defined further down in the module; it may not exist yet
    # while the module body is still executing.
    if "_fallback" in globals():
        return _fallback(name)

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
55+
56+
4157
_sharpy_cw = bool(int(getenv("SHARPY_CW", False)))
4258

4359
pi = 3.1415926535897932384626433
@@ -185,7 +201,3 @@ def __getattr__(self, name):
185201
dt.linalg.norm(...)
186202
"""
187203
return _fallback(name, self._func)
188-
189-
def __getattr__(name):
190-
"Attempt to find a fallback in fallback-lib"
191-
return _fallback(name)

sharpy/random.py

-11
This file was deleted.

sharpy/random/__init__.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
3+
import sharpy as sp
4+
from sharpy import float64
5+
from sharpy.numpy import fromfunction
6+
7+
8+
def uniform(low, high, size, device="", team=1):
    """
    Draw samples with ``numpy.random.uniform(low, high, size)`` and return
    them as a sharpy array.

    A zero-dimensional draw (``size=()`` or ``size=None``) is returned as a
    0-d array created with ``sp.empty(())``; ``device`` and ``team`` are not
    used on that path. All other shapes are built with ``fromfunction`` using
    ``dtype=float64`` and the given ``device``/``team``.
    """
    # numpy returns a plain Python float for size=None, which has no
    # ``.shape``; asarray turns it into a 0-d array and leaves real arrays
    # untouched.
    data = np.asarray(np.random.uniform(low, high, size))
    if data.ndim == 0:
        sp_data = sp.empty(())
        sp_data[()] = data[()]
        return sp_data
    return fromfunction(
        lambda *index: data[index],
        data.shape,
        dtype=float64,
        device=device,
        team=team,
    )
21+
22+
23+
def rand(*shape, device="", team=1):
    """
    Draw samples with ``numpy.random.rand(*shape)``.

    With no ``shape`` arguments numpy yields a Python float, which is
    returned unchanged. Otherwise the samples are wrapped into an array via
    ``fromfunction`` with ``dtype=float64`` and the given ``device``/``team``.
    """
    samples = np.random.rand(*shape)
    if not isinstance(samples, float):
        return fromfunction(
            lambda *idx: samples[idx],
            samples.shape,
            dtype=float64,
            device=device,
            team=team,
        )
    return samples
34+
35+
36+
def seed(s):
    """Seed numpy's global random generator, which backs this module's draws."""
    np.random.seed(seed=s)

src/Service.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ struct DeferredService : public DeferredT<Service::service_promise_type,
5151
// drop from dep manager
5252
dm.drop(_a);
5353
// and from registry
54-
Registry::del(_a);
54+
dm.addReady(_a, [this](id_type guid) {
55+
assert(this->_a == guid);
56+
Registry::del(guid);
57+
});
5558
break;
5659
}
5760
case RUN:

src/idtr.cpp

+40-5
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,21 @@ void unpack(void *in, SHARPY::DTypeId dtype, const int64_t *sizes,
268268
});
269269
}
270270

271+
/// copy contiguous block of data into a possibly strided array
272+
void unpack1(void *in, SHARPY::DTypeId dtype, const int64_t *sizes,
273+
const int64_t *strides, uint64_t ndim, void *out) {
274+
if (!in || !sizes || !strides || !out) {
275+
return;
276+
}
277+
dispatch(dtype, out, [sizes, strides, ndim, in](auto *out_) {
278+
auto in_ = static_cast<decltype(out_)>(in);
279+
SHARPY::forall(0, out_, sizes, strides, ndim, [&in_](auto *out) {
280+
*out = *in_;
281+
++in_;
282+
});
283+
});
284+
}
285+
271286
template <typename T>
272287
void copy_(uint64_t d, uint64_t &pos, T *cptr, const int64_t *sizes,
273288
const int64_t *strides, const uint64_t *chunks, uint64_t nd,
@@ -489,21 +504,41 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
489504
}
490505
}
491506

507+
int64_t oStride = std::accumulate(oDataStridesPtr, oDataStridesPtr + oNDims,
508+
1, std::multiplies<int64_t>());
509+
void *rBuff = oDataPtr;
510+
if (oStride != 1) {
511+
rBuff = new char[sizeof_dtype(sharpytype) * myOSz];
512+
}
513+
492514
SHARPY::Buffer sendbuff(totSSz * sizeof_dtype(sharpytype), 2);
493515
bufferizeN(iNDims, iDataPtr, iDataShapePtr, iDataStridesPtr, sharpytype, N,
494516
lsOffs.data(), lsEnds.data(), sendbuff.data());
495517
auto hdl = tc->alltoall(sendbuff.data(), sszs.data(), soffs.data(),
496-
sharpytype, oDataPtr, rszs.data(), roffs.data());
518+
sharpytype, rBuff, rszs.data(), roffs.data());
497519

498520
if (no_async) {
499521
tc->wait(hdl);
522+
if (oStride != 1) {
523+
unpack1(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
524+
oDataPtr);
525+
delete[](char *) rBuff;
526+
}
500527
return nullptr;
501528
}
502529

503-
auto wait = [tc = tc, hdl = hdl, sendbuff = std::move(sendbuff),
504-
sszs = std::move(sszs), soffs = std::move(soffs),
505-
rszs = std::move(rszs),
506-
roffs = std::move(roffs)]() { tc->wait(hdl); };
530+
auto wait = [tc, hdl, oStride, rBuff, sharpytype, oDataShapePtr,
531+
oDataStridesPtr, oNDims, oDataPtr,
532+
sendbuff = std::move(sendbuff), sszs = std::move(sszs),
533+
soffs = std::move(soffs), rszs = std::move(rszs),
534+
roffs = std::move(roffs)]() {
535+
tc->wait(hdl);
536+
if (oStride != 1) {
537+
unpack1(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
538+
oDataPtr);
539+
delete[](char *) rBuff;
540+
}
541+
};
507542
assert(sendbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() &&
508543
roffs.empty());
509544
return mkWaitHandle(std::move(wait));

0 commit comments

Comments
 (0)