@@ -18,14 +18,14 @@ def __init__(self, iterations, n, tile_size, star, radius):
18
18
self .out : pk .View2D [pk .double ] = pk .View ([self .n , self .n ], pk .double , layout = pk .Layout .LayoutRight )
19
19
self .norm : float = 0
20
20
21
- self .stencil_time : float = 0
21
+ self .stencil_time : float = 0
22
22
23
23
@pk .main
24
24
def run (self ):
25
25
t : int = tile_size
26
26
r : int = radius
27
27
28
- pk .parallel_for (pk .MDRangePolicy ([0 ,0 ], [n , n ], [t , t ]),
28
+ pk .parallel_for (pk .MDRangePolicy ([0 ,0 ], [n , n ], [t , t ]),
29
29
self .init )
30
30
pk .fence ()
31
31
@@ -34,7 +34,7 @@ def run(self):
34
34
for i in range (iterations ):
35
35
if (i == 1 ):
36
36
pk .fence ()
37
-
37
+
38
38
if r == 1 :
39
39
# star1 stencil
40
40
pk .parallel_for ("stencil" , pk .MDRangePolicy ([r ,r ], [n - r , n - r ], [t , t ]), self .star1 )
@@ -45,8 +45,8 @@ def run(self):
45
45
# star3 stencil
46
46
pk .parallel_for ("stencil" , pk .MDRangePolicy ([r ,r ], [n - r , n - r ], [t , t ]), self .star3 )
47
47
48
-
49
- pk .parallel_for (pk .MDRangePolicy ([0 ,0 ], [n , n ], [t , t ]),
48
+
49
+ pk .parallel_for (pk .MDRangePolicy ([0 ,0 ], [n , n ], [t , t ]),
50
50
self .increment )
51
51
52
52
pk .fence ()
@@ -55,7 +55,7 @@ def run(self):
55
55
active_points : int = (n - 2 * r )* (n - 2 * r )
56
56
57
57
# verify correctness
58
- self .norm = pk .parallel_reduce (pk .MDRangePolicy ([r , r ], [n - r , n - r ], [t , t ]),
58
+ self .norm = pk .parallel_reduce (pk .MDRangePolicy ([r , r ], [n - r , n - r ], [t , t ]),
59
59
self .norm_reduce )
60
60
pk .fence ()
61
61
self .norm /= active_points
@@ -78,7 +78,7 @@ def increment(self, i: int, j: int):
78
78
79
79
@pk .workunit
80
80
def norm_reduce (self , i : int , j : int , acc : pk .Acc [pk .double ]):
81
- acc += abs (self .out [i ][j ])
81
+ acc += abs (self .out [i ][j ])
82
82
83
83
# @pk.callback
84
84
# def print_result(self):
@@ -121,7 +121,7 @@ def star3(self, i: int, j: int):
121
121
+ self .inp [i ][j + 2 ] * 0.08333333333333333 \
122
122
+ self .inp [i ][j + 3 ] * 0.05555555555555555
123
123
124
- if __name__ == "__main__" :
124
+ def run () -> None :
125
125
parser = argparse .ArgumentParser ()
126
126
parser .add_argument ('iterations' , type = int )
127
127
parser .add_argument ('n' , type = int )
@@ -169,9 +169,11 @@ def star3(self, i: int, j: int):
169
169
170
170
n = 2 ** n
171
171
print ("Number of iterations = " , iterations )
172
- print ("Grid size = " , n )
172
+ print ("Grid size = " , n )
173
173
print ("Tile size = " , tile_size )
174
174
print ("Type of stencil = " , "star" if star else "grid" )
175
175
print ("Radius of stencil = " , radius )
176
176
pk .execute (pk .ExecutionSpace .Default , main (iterations , n , tile_size , star , radius ))
177
177
178
+ if __name__ == "__main__" :
179
+ run ()
0 commit comments