@@ -134,7 +134,7 @@ def run(n, backend, datatype, benchmark_mode):
134134 info (f"Total DOFs: { dofs_T + dofs_U + dofs_V } " )
135135
136136 # prognostic variables: elevation, (u, v) velocity
137- # e = create_full(T_shape, 0.0, dtype)
137+ e = create_full (T_shape , 0.0 , dtype )
138138 u = create_full (U_shape , 0.0 , dtype )
139139 v = create_full (V_shape , 0.0 , dtype )
140140
@@ -167,9 +167,7 @@ def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
167167 return amp * sol_x * sol_y * sol_t
168168
169169 # initial elevation
170- # e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly)
171- # NOTE assignment fails, do not pre-allocate e
172- e = exact_elev (0.0 , x_t_2d , y_t_2d , lx , ly ).to_device (device )
170+ e [:, :] = exact_elev (0.0 , x_t_2d , y_t_2d , lx , ly )
173171 sync ()
174172
175173 # compute time step
@@ -235,8 +233,8 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
235233 t = i * dt
236234
237235 if t >= next_t_export - 1e-8 :
238- _elev_max = 0 # np.max(e, all_axes)
239- _u_max = 0 # np.max(u, all_axes)
236+ _elev_max = e [ 0 , 0 ]. to_device () # np.max(e, all_axes)
237+ _u_max = u [ 0 , 0 ]. to_device () # np.max(u, all_axes)
240238 _total_v = 0 # np.sum(e + h, all_axes)
241239
242240 elev_max = float (_elev_max )
0 commit comments