28
28
import time as time_mod
29
29
import argparse
30
30
31
+ import os
32
+
33
+ device = os .getenv ("SHARPY_USE_GPU" , "" )
34
+
31
35
32
36
def run (n , backend , datatype , benchmark_mode ):
33
37
if backend == "sharpy" :
@@ -94,16 +98,24 @@ def info(s):
94
98
t_end = 1.0
95
99
96
100
# coordinate arrays
97
- x_t_2d = fromfunction (lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype )
98
- y_t_2d = fromfunction (lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype )
99
- x_u_2d = fromfunction (lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype )
101
+ x_t_2d = fromfunction (
102
+ lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype , device = device
103
+ )
104
+ y_t_2d = fromfunction (
105
+ lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype , device = device
106
+ )
107
+ x_u_2d = fromfunction (
108
+ lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype , device = device
109
+ )
100
110
y_u_2d = fromfunction (
101
- lambda i , j : ymin + j * dy + dy / 2 , (nx + 1 , ny ), dtype = dtype
111
+ lambda i , j : ymin + j * dy + dy / 2 , (nx + 1 , ny ), dtype = dtype , device = device
102
112
)
103
113
x_v_2d = fromfunction (
104
- lambda i , j : xmin + i * dx + dx / 2 , (nx , ny + 1 ), dtype = dtype
114
+ lambda i , j : xmin + i * dx + dx / 2 , (nx , ny + 1 ), dtype = dtype , device = device
115
+ )
116
+ y_v_2d = fromfunction (
117
+ lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype , device = device
105
118
)
106
- y_v_2d = fromfunction (lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype )
107
119
108
120
T_shape = (nx , ny )
109
121
U_shape = (nx + 1 , ny )
@@ -120,32 +132,32 @@ def info(s):
120
132
info (f"Total DOFs: { dofs_T + dofs_U + dofs_V } " )
121
133
122
134
# prognostic variables: elevation, (u, v) velocity
123
- e = np .full (T_shape , 0.0 , dtype )
124
- u = np .full (U_shape , 0.0 , dtype )
125
- v = np .full (V_shape , 0.0 , dtype )
135
+ e = np .full (T_shape , 0.0 , dtype , device = device )
136
+ u = np .full (U_shape , 0.0 , dtype , device = device )
137
+ v = np .full (V_shape , 0.0 , dtype , device = device )
126
138
127
139
# potential vorticity
128
- q = np .full (F_shape , 0.0 , dtype )
140
+ q = np .full (F_shape , 0.0 , dtype , device = device )
129
141
130
142
# bathymetry
131
- h = np .full (T_shape , 0.0 , dtype )
143
+ h = np .full (T_shape , 0.0 , dtype , device = device )
132
144
133
- hu = np .full (U_shape , 0.0 , dtype )
134
- hv = np .full (V_shape , 0.0 , dtype )
145
+ hu = np .full (U_shape , 0.0 , dtype , device = device )
146
+ hv = np .full (V_shape , 0.0 , dtype , device = device )
135
147
136
- dudy = np .full (F_shape , 0.0 , dtype )
137
- dvdx = np .full (F_shape , 0.0 , dtype )
148
+ dudy = np .full (F_shape , 0.0 , dtype , device = device )
149
+ dvdx = np .full (F_shape , 0.0 , dtype , device = device )
138
150
139
151
# vector invariant form
140
- H_at_f = np .full (F_shape , 0.0 , dtype )
152
+ H_at_f = np .full (F_shape , 0.0 , dtype , device = device )
141
153
142
154
# auxiliary variables for RK time integration
143
- e1 = np .full (T_shape , 0.0 , dtype )
144
- u1 = np .full (U_shape , 0.0 , dtype )
145
- v1 = np .full (V_shape , 0.0 , dtype )
146
- e2 = np .full (T_shape , 0.0 , dtype )
147
- u2 = np .full (U_shape , 0.0 , dtype )
148
- v2 = np .full (V_shape , 0.0 , dtype )
155
+ e1 = np .full (T_shape , 0.0 , dtype , device = device )
156
+ u1 = np .full (U_shape , 0.0 , dtype , device = device )
157
+ v1 = np .full (V_shape , 0.0 , dtype , device = device )
158
+ e2 = np .full (T_shape , 0.0 , dtype , device = device )
159
+ u2 = np .full (U_shape , 0.0 , dtype , device = device )
160
+ v2 = np .full (V_shape , 0.0 , dtype , device = device )
149
161
150
162
def exact_solution (t , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d ):
151
163
"""
@@ -174,7 +186,7 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
174
186
Water depth at rest
175
187
"""
176
188
bath = 1.0
177
- return bath * np .full (T_shape , 1.0 , dtype )
189
+ return bath * np .full (T_shape , 1.0 , dtype , device = device )
178
190
179
191
# inital elevation
180
192
u0 , v0 , e0 = exact_solution (0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d )
0 commit comments