Skip to content

Commit 173cb06

Browse files
authored
Add datatype flag to wave and shallow water examples (#61)
1 parent 7c98509 commit 173cb06

File tree

3 files changed

+92
-85
lines changed

3 files changed

+92
-85
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,14 @@ jobs:
147147
mpirun -n 2 -genv DDPT_FALLBACK=numpy python -u ./stencil-2d.py 5 2048 star 2
148148
mpirun -n 3 -genv DDPT_FALLBACK=numpy python -u ./stencil-2d.py 5 2048 star 2
149149
mpirun -n 4 -genv DDPT_FALLBACK=numpy python -u ./stencil-2d.py 5 2048 star 2
150-
python -u ./wave_equation.py
150+
python -u ./wave_equation.py -d f32
151+
python -u ./wave_equation.py -d f64
151152
DDPT_FORCE_DIST=1 python -u ./wave_equation.py
152153
mpirun -n 2 python -u ./wave_equation.py
153154
mpirun -n 3 python -u ./wave_equation.py
154155
mpirun -n 4 python -u ./wave_equation.py
155-
python -u ./shallow_water.py
156+
python -u ./shallow_water.py -d f32
157+
python -u ./shallow_water.py -d f64
156158
cd -
157159
- name: Cleanup
158160
run: |

examples/shallow_water.py

Lines changed: 53 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import argparse
3030

3131

32-
def run(n, backend, benchmark_mode, correctness_test):
32+
def run(n, backend, datatype, benchmark_mode):
3333
if backend == "ddpt":
3434
import ddptensor as np
3535
from ddptensor.numpy import fromfunction
@@ -64,8 +64,11 @@ def info(s):
6464

6565
info(f"Using backend: {backend}")
6666

67-
if correctness_test:
68-
n = 10
67+
dtype = {
68+
"f64": np.float64,
69+
"f32": np.float32,
70+
}[datatype]
71+
info(f"Datatype: {datatype}")
6972

7073
# constants
7174
g = 9.81
@@ -92,20 +95,16 @@ def info(s):
9295
t_end = 1.0
9396

9497
# coordinate arrays
95-
x_t_2d = fromfunction(
96-
lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=np.float64
97-
)
98-
y_t_2d = fromfunction(
99-
lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=np.float64
100-
)
101-
x_u_2d = fromfunction(lambda i, j: xmin + i * dx, (nx + 1, ny), dtype=np.float64)
98+
x_t_2d = fromfunction(lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=dtype)
99+
y_t_2d = fromfunction(lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=dtype)
100+
x_u_2d = fromfunction(lambda i, j: xmin + i * dx, (nx + 1, ny), dtype=dtype)
102101
y_u_2d = fromfunction(
103-
lambda i, j: ymin + j * dy + dy / 2, (nx + 1, ny), dtype=np.float64
102+
lambda i, j: ymin + j * dy + dy / 2, (nx + 1, ny), dtype=dtype
104103
)
105104
x_v_2d = fromfunction(
106-
lambda i, j: xmin + i * dx + dx / 2, (nx, ny + 1), dtype=np.float64
105+
lambda i, j: xmin + i * dx + dx / 2, (nx, ny + 1), dtype=dtype
107106
)
108-
y_v_2d = fromfunction(lambda i, j: ymin + j * dy, (nx, ny + 1), dtype=np.float64)
107+
y_v_2d = fromfunction(lambda i, j: ymin + j * dy, (nx, ny + 1), dtype=dtype)
109108

110109
T_shape = (nx, ny)
111110
U_shape = (nx + 1, ny)
@@ -122,32 +121,32 @@ def info(s):
122121
info(f"Total DOFs: {dofs_T + dofs_U + dofs_V}")
123122

124123
# prognostic variables: elevation, (u, v) velocity
125-
e = np.full(T_shape, 0.0, np.float64)
126-
u = np.full(U_shape, 0.0, np.float64)
127-
v = np.full(V_shape, 0.0, np.float64)
124+
e = np.full(T_shape, 0.0, dtype)
125+
u = np.full(U_shape, 0.0, dtype)
126+
v = np.full(V_shape, 0.0, dtype)
128127

129128
# potential vorticity
130-
q = np.full(F_shape, 0.0, np.float64)
129+
q = np.full(F_shape, 0.0, dtype)
131130

132131
# bathymetry
133-
h = np.full(T_shape, 0.0, np.float64)
132+
h = np.full(T_shape, 0.0, dtype)
134133

135-
hu = np.full(U_shape, 0.0, np.float64)
136-
hv = np.full(V_shape, 0.0, np.float64)
134+
hu = np.full(U_shape, 0.0, dtype)
135+
hv = np.full(V_shape, 0.0, dtype)
137136

138-
dudy = np.full(F_shape, 0.0, np.float64)
139-
dvdx = np.full(F_shape, 0.0, np.float64)
137+
dudy = np.full(F_shape, 0.0, dtype)
138+
dvdx = np.full(F_shape, 0.0, dtype)
140139

141140
# vector invariant form
142-
H_at_f = np.full(F_shape, 0.0, np.float64)
141+
H_at_f = np.full(F_shape, 0.0, dtype)
143142

144143
# auxiliary variables for RK time integration
145-
e1 = np.full(T_shape, 0.0, np.float64)
146-
u1 = np.full(U_shape, 0.0, np.float64)
147-
v1 = np.full(V_shape, 0.0, np.float64)
148-
e2 = np.full(T_shape, 0.0, np.float64)
149-
u2 = np.full(U_shape, 0.0, np.float64)
150-
v2 = np.full(V_shape, 0.0, np.float64)
144+
e1 = np.full(T_shape, 0.0, dtype)
145+
u1 = np.full(U_shape, 0.0, dtype)
146+
v1 = np.full(V_shape, 0.0, dtype)
147+
e2 = np.full(T_shape, 0.0, dtype)
148+
u2 = np.full(U_shape, 0.0, dtype)
149+
v2 = np.full(V_shape, 0.0, dtype)
151150

152151
def exact_solution(t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d):
153152
"""
@@ -176,7 +175,7 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
176175
Water depth at rest
177176
"""
178177
bath = 1.0
179-
return bath * np.full(T_shape, 1.0, np.float64)
178+
return bath * np.full(T_shape, 1.0, dtype)
180179

181180
# inital elevation
182181
u0, v0, e0 = exact_solution(0, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d)
@@ -200,10 +199,6 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
200199
dt = 1e-5
201200
nt = 100
202201
t_export = dt * 25
203-
if correctness_test:
204-
dt = 0.02
205-
nt = 10
206-
t_export = dt * 2
207202

208203
info(f"Time step: {dt} s")
209204
info(f"Total run time: {t_end} s, {nt} time steps")
@@ -381,20 +376,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
381376
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
382377
info(f"L2 error: {err_L2:7.15e}")
383378

384-
if correctness_test:
385-
assert numpy.allclose(err_L2, 3.687334565903038e-04), "L2 error does not match"
386-
info("SUCCESS")
387-
elif nx < 128 or ny < 128:
379+
if nx < 128 or ny < 128:
388380
info("Skipping correctness test due to small problem size.")
389381
elif not benchmark_mode:
390-
tolerance_ene = 1e-8
382+
tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
391383
assert (
392384
diff_e < tolerance_ene
393385
), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
394386
if nx == 128 and ny == 128:
395-
assert numpy.allclose(
396-
err_L2, 4.315799035627906e-05
397-
), "L2 error does not match"
387+
if datatype == "f32":
388+
assert numpy.allclose(
389+
err_L2, 4.3127859e-05, rtol=1e-5
390+
), "L2 error does not match"
391+
else:
392+
assert numpy.allclose(
393+
err_L2, 4.315799035627906e-05
394+
), "L2 error does not match"
398395
else:
399396
tolerance_l2 = 1e-4
400397
assert (
@@ -423,12 +420,6 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
423420
action="store_true",
424421
help="Run a fixed number of time steps.",
425422
)
426-
parser.add_argument(
427-
"-ct",
428-
"--correctness-test",
429-
action="store_true",
430-
help="Run a minimal correctness test.",
431-
)
432423
parser.add_argument(
433424
"-b",
434425
"--backend",
@@ -437,5 +428,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
437428
choices=["ddpt", "numpy"],
438429
help="Backend to use.",
439430
)
431+
parser.add_argument(
432+
"-d",
433+
"--datatype",
434+
type=str,
435+
default="f64",
436+
choices=["f32", "f64"],
437+
help="Datatype for model state variables",
438+
)
440439
args = parser.parse_args()
441-
run(args.resolution, args.backend, args.benchmark_mode, args.correctness_test)
440+
run(
441+
args.resolution,
442+
args.backend,
443+
args.datatype,
444+
args.benchmark_mode,
445+
)

examples/wave_equation.py

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import argparse
3030

3131

32-
def run(n, backend, benchmark_mode, correctness_test):
32+
def run(n, backend, datatype, benchmark_mode):
3333
if backend == "ddpt":
3434
import ddptensor as np
3535
from ddptensor.numpy import fromfunction
@@ -64,8 +64,11 @@ def info(s):
6464

6565
info(f"Using backend: {backend}")
6666

67-
if correctness_test:
68-
n = 10
67+
dtype = {
68+
"f64": np.float64,
69+
"f32": np.float32,
70+
}[datatype]
71+
info(f"Datatype: {datatype}")
6972

7073
# constants
7174
h = 1.0
@@ -92,12 +95,8 @@ def info(s):
9295
t_end = 1.0
9396

9497
# coordinate arrays
95-
x_t_2d = fromfunction(
96-
lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=np.float64
97-
)
98-
y_t_2d = fromfunction(
99-
lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=np.float64
100-
)
98+
x_t_2d = fromfunction(lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=dtype)
99+
y_t_2d = fromfunction(lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=dtype)
101100

102101
T_shape = (nx, ny)
103102
U_shape = (nx + 1, ny)
@@ -113,17 +112,17 @@ def info(s):
113112
info(f"Total DOFs: {dofs_T + dofs_U + dofs_V}")
114113

115114
# prognostic variables: elevation, (u, v) velocity
116-
e = np.full(T_shape, 0.0, np.float64)
117-
u = np.full(U_shape, 0.0, np.float64)
118-
v = np.full(V_shape, 0.0, np.float64)
115+
e = np.full(T_shape, 0.0, dtype)
116+
u = np.full(U_shape, 0.0, dtype)
117+
v = np.full(V_shape, 0.0, dtype)
119118

120119
# auxiliary variables for RK time integration
121-
e1 = np.full(T_shape, 0.0, np.float64)
122-
u1 = np.full(U_shape, 0.0, np.float64)
123-
v1 = np.full(V_shape, 0.0, np.float64)
124-
e2 = np.full(T_shape, 0.0, np.float64)
125-
u2 = np.full(U_shape, 0.0, np.float64)
126-
v2 = np.full(V_shape, 0.0, np.float64)
120+
e1 = np.full(T_shape, 0.0, dtype)
121+
u1 = np.full(U_shape, 0.0, dtype)
122+
v1 = np.full(V_shape, 0.0, dtype)
123+
e2 = np.full(T_shape, 0.0, dtype)
124+
u2 = np.full(U_shape, 0.0, dtype)
125+
v2 = np.full(V_shape, 0.0, dtype)
127126

128127
def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
129128
"""
@@ -156,10 +155,6 @@ def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
156155
dt = 1e-5
157156
nt = 100
158157
t_export = dt * 25
159-
if correctness_test:
160-
dt = 0.02
161-
nt = 10
162-
t_export = dt * 2
163158

164159
info(f"Time step: {dt} s")
165160
info(f"Total run time: {t_end} s, {nt} time steps")
@@ -253,11 +248,10 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
253248
info(f"L2 error: {err_L2:7.5e}")
254249

255250
if nx == 128 and ny == 128 and not benchmark_mode:
256-
assert numpy.allclose(err_L2, 7.224068445111e-03)
257-
info("SUCCESS")
258-
259-
if correctness_test:
260-
assert numpy.allclose(err_L2, 1.317066179876e-02)
251+
if datatype == "f32":
252+
assert numpy.allclose(err_L2, 7.2235471e-03, rtol=1e-4)
253+
else:
254+
assert numpy.allclose(err_L2, 7.224068445111e-03)
261255
info("SUCCESS")
262256

263257
fini()
@@ -281,12 +275,6 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
281275
action="store_true",
282276
help="Run a fixed number of time steps.",
283277
)
284-
parser.add_argument(
285-
"-ct",
286-
"--correctness-test",
287-
action="store_true",
288-
help="Run a minimal correctness test.",
289-
)
290278
parser.add_argument(
291279
"-b",
292280
"--backend",
@@ -295,5 +283,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
295283
choices=["ddpt", "numpy"],
296284
help="Backend to use.",
297285
)
286+
parser.add_argument(
287+
"-d",
288+
"--datatype",
289+
type=str,
290+
default="f64",
291+
choices=["f32", "f64"],
292+
help="Datatype for model state variables",
293+
)
298294
args = parser.parse_args()
299-
run(args.resolution, args.backend, args.benchmark_mode, args.correctness_test)
295+
run(
296+
args.resolution,
297+
args.backend,
298+
args.datatype,
299+
args.benchmark_mode,
300+
)

0 commit comments

Comments
 (0)