Skip to content

Commit 189f892

Browse files
committed
wave-equation: simplify reduction calls
1 parent 8ab6ee7 commit 189f892

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

examples/wave_equation.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode):
6161
def transpose(a):
6262
return np.permute_dims(a, [1, 0])
6363

64-
all_axes = [0, 1]
6564
init(False)
6665

6766
elif backend == "numpy":
@@ -76,7 +75,6 @@ def transpose(a):
7675
transpose = np.transpose
7776

7877
fini = sync = lambda x=None: None
79-
all_axes = None
8078
else:
8179
raise ValueError(f'Unknown backend: "{backend}"')
8280

@@ -240,9 +238,9 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
240238
t = i * dt
241239

242240
if t >= next_t_export - 1e-8:
243-
_elev_max = np.max(e, all_axes)
244-
_u_max = np.max(u, all_axes)
245-
_total_v = np.sum(e + h, all_axes)
241+
_elev_max = np.max(e)
242+
_u_max = np.max(u)
243+
_total_v = np.sum(e + h)
246244

247245
elev_max = float(_elev_max)
248246
u_max = float(_u_max)
@@ -279,7 +277,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
279277

280278
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
281279
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
282-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
280+
err_L2 = math.sqrt(float(np.sum(err2)))
283281
info(f"L2 error: {err_L2:7.5e}")
284282

285283
if nx == 128 and ny == 128 and not benchmark_mode:

0 commit comments

Comments
 (0)