29
29
import argparse
30
30
31
31
32
- def run (n , backend , benchmark_mode , correctness_test ):
32
+ def run (n , backend , datatype , benchmark_mode ):
33
33
if backend == "ddpt" :
34
34
import ddptensor as np
35
35
from ddptensor .numpy import fromfunction
@@ -64,8 +64,11 @@ def info(s):
64
64
65
65
info (f"Using backend: { backend } " )
66
66
67
- if correctness_test :
68
- n = 10
67
+ dtype = {
68
+ "f64" : np .float64 ,
69
+ "f32" : np .float32 ,
70
+ }[datatype ]
71
+ info (f"Datatype: { datatype } " )
69
72
70
73
# constants
71
74
g = 9.81
@@ -92,20 +95,16 @@ def info(s):
92
95
t_end = 1.0
93
96
94
97
# coordinate arrays
95
- x_t_2d = fromfunction (
96
- lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = np .float64
97
- )
98
- y_t_2d = fromfunction (
99
- lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = np .float64
100
- )
101
- x_u_2d = fromfunction (lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = np .float64 )
98
+ x_t_2d = fromfunction (lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype )
99
+ y_t_2d = fromfunction (lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype )
100
+ x_u_2d = fromfunction (lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype )
102
101
y_u_2d = fromfunction (
103
- lambda i , j : ymin + j * dy + dy / 2 , (nx + 1 , ny ), dtype = np . float64
102
+ lambda i , j : ymin + j * dy + dy / 2 , (nx + 1 , ny ), dtype = dtype
104
103
)
105
104
x_v_2d = fromfunction (
106
- lambda i , j : xmin + i * dx + dx / 2 , (nx , ny + 1 ), dtype = np . float64
105
+ lambda i , j : xmin + i * dx + dx / 2 , (nx , ny + 1 ), dtype = dtype
107
106
)
108
- y_v_2d = fromfunction (lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = np . float64 )
107
+ y_v_2d = fromfunction (lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype )
109
108
110
109
T_shape = (nx , ny )
111
110
U_shape = (nx + 1 , ny )
@@ -122,32 +121,32 @@ def info(s):
122
121
info (f"Total DOFs: { dofs_T + dofs_U + dofs_V } " )
123
122
124
123
# prognostic variables: elevation, (u, v) velocity
125
- e = np .full (T_shape , 0.0 , np . float64 )
126
- u = np .full (U_shape , 0.0 , np . float64 )
127
- v = np .full (V_shape , 0.0 , np . float64 )
124
+ e = np .full (T_shape , 0.0 , dtype )
125
+ u = np .full (U_shape , 0.0 , dtype )
126
+ v = np .full (V_shape , 0.0 , dtype )
128
127
129
128
# potential vorticity
130
- q = np .full (F_shape , 0.0 , np . float64 )
129
+ q = np .full (F_shape , 0.0 , dtype )
131
130
132
131
# bathymetry
133
- h = np .full (T_shape , 0.0 , np . float64 )
132
+ h = np .full (T_shape , 0.0 , dtype )
134
133
135
- hu = np .full (U_shape , 0.0 , np . float64 )
136
- hv = np .full (V_shape , 0.0 , np . float64 )
134
+ hu = np .full (U_shape , 0.0 , dtype )
135
+ hv = np .full (V_shape , 0.0 , dtype )
137
136
138
- dudy = np .full (F_shape , 0.0 , np . float64 )
139
- dvdx = np .full (F_shape , 0.0 , np . float64 )
137
+ dudy = np .full (F_shape , 0.0 , dtype )
138
+ dvdx = np .full (F_shape , 0.0 , dtype )
140
139
141
140
# vector invariant form
142
- H_at_f = np .full (F_shape , 0.0 , np . float64 )
141
+ H_at_f = np .full (F_shape , 0.0 , dtype )
143
142
144
143
# auxiliary variables for RK time integration
145
- e1 = np .full (T_shape , 0.0 , np . float64 )
146
- u1 = np .full (U_shape , 0.0 , np . float64 )
147
- v1 = np .full (V_shape , 0.0 , np . float64 )
148
- e2 = np .full (T_shape , 0.0 , np . float64 )
149
- u2 = np .full (U_shape , 0.0 , np . float64 )
150
- v2 = np .full (V_shape , 0.0 , np . float64 )
144
+ e1 = np .full (T_shape , 0.0 , dtype )
145
+ u1 = np .full (U_shape , 0.0 , dtype )
146
+ v1 = np .full (V_shape , 0.0 , dtype )
147
+ e2 = np .full (T_shape , 0.0 , dtype )
148
+ u2 = np .full (U_shape , 0.0 , dtype )
149
+ v2 = np .full (V_shape , 0.0 , dtype )
151
150
152
151
def exact_solution (t , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d ):
153
152
"""
@@ -176,7 +175,7 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
176
175
Water depth at rest
177
176
"""
178
177
bath = 1.0
179
- return bath * np .full (T_shape , 1.0 , np . float64 )
178
+ return bath * np .full (T_shape , 1.0 , dtype )
180
179
181
180
# inital elevation
182
181
u0 , v0 , e0 = exact_solution (0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d )
@@ -200,10 +199,6 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
200
199
dt = 1e-5
201
200
nt = 100
202
201
t_export = dt * 25
203
- if correctness_test :
204
- dt = 0.02
205
- nt = 10
206
- t_export = dt * 2
207
202
208
203
info (f"Time step: { dt } s" )
209
204
info (f"Total run time: { t_end } s, { nt } time steps" )
@@ -381,20 +376,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
381
376
err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
382
377
info (f"L2 error: { err_L2 :7.15e} " )
383
378
384
- if correctness_test :
385
- assert numpy .allclose (err_L2 , 3.687334565903038e-04 ), "L2 error does not match"
386
- info ("SUCCESS" )
387
- elif nx < 128 or ny < 128 :
379
+ if nx < 128 or ny < 128 :
388
380
info ("Skipping correctness test due to small problem size." )
389
381
elif not benchmark_mode :
390
- tolerance_ene = 1e-8
382
+ tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
391
383
assert (
392
384
diff_e < tolerance_ene
393
385
), f"Energy error exceeds tolerance: { diff_e } > { tolerance_ene } "
394
386
if nx == 128 and ny == 128 :
395
- assert numpy .allclose (
396
- err_L2 , 4.315799035627906e-05
397
- ), "L2 error does not match"
387
+ if datatype == "f32" :
388
+ assert numpy .allclose (
389
+ err_L2 , 4.3127859e-05 , rtol = 1e-5
390
+ ), "L2 error does not match"
391
+ else :
392
+ assert numpy .allclose (
393
+ err_L2 , 4.315799035627906e-05
394
+ ), "L2 error does not match"
398
395
else :
399
396
tolerance_l2 = 1e-4
400
397
assert (
@@ -423,12 +420,6 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
423
420
action = "store_true" ,
424
421
help = "Run a fixed number of time steps." ,
425
422
)
426
- parser .add_argument (
427
- "-ct" ,
428
- "--correctness-test" ,
429
- action = "store_true" ,
430
- help = "Run a minimal correctness test." ,
431
- )
432
423
parser .add_argument (
433
424
"-b" ,
434
425
"--backend" ,
@@ -437,5 +428,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
437
428
choices = ["ddpt" , "numpy" ],
438
429
help = "Backend to use." ,
439
430
)
431
+ parser .add_argument (
432
+ "-d" ,
433
+ "--datatype" ,
434
+ type = str ,
435
+ default = "f64" ,
436
+ choices = ["f32" , "f64" ],
437
+ help = "Datatype for model state variables" ,
438
+ )
440
439
args = parser .parse_args ()
441
- run (args .resolution , args .backend , args .benchmark_mode , args .correctness_test )
440
+ run (
441
+ args .resolution ,
442
+ args .backend ,
443
+ args .datatype ,
444
+ args .benchmark_mode ,
445
+ )
0 commit comments