@@ -58,22 +58,165 @@ function transform_gpu!(def, constargs, force_inbounds)
58
58
end
59
59
end
60
60
pushfirst! (def[:args ], :__ctx__ )
61
- body = def[:body ]
61
+ new_stmts = Expr[]
62
+ body = MacroTools. flatten (def[:body ])
63
+ stmts = body. args
64
+ push! (new_stmts, Expr (:aliasscope ))
65
+ push! (new_stmts, :(__active_lane__ = $ __validindex (__ctx__)))
62
66
if force_inbounds
63
- body = quote
64
- @inbounds $ (body)
65
- end
67
+ push! (new_stmts, Expr (:inbounds , true ))
66
68
end
67
- body = quote
68
- if $ __validindex (__ctx__)
69
- $ (body)
70
- end
71
- return nothing
69
+ append! (new_stmts, split (emit_gpu, body. args))
70
+ if force_inbounds
71
+ push! (new_stmts, Expr (:inbounds , :pop ))
72
72
end
73
+ push! (new_stmts, Expr (:popaliasscope ))
74
+ push! (new_stmts, :(return nothing ))
73
75
def[:body ] = Expr (
74
76
:let ,
75
77
Expr (:block , let_constargs... ),
76
- body ,
78
+ Expr ( :block , new_stmts ... ) ,
77
79
)
78
80
return
79
81
end
82
+
83
+ struct WorkgroupLoop
84
+ indicies:: Vector{Any}
85
+ stmts:: Vector{Any}
86
+ allocations:: Vector{Any}
87
+ private_allocations:: Vector{Any}
88
+ private:: Set{Symbol}
89
+ end
90
+
91
+ is_sync (expr) = @capture (expr, @synchronize () | @synchronize (a_))
92
+
93
+ function is_scope_construct (expr:: Expr )
94
+ return expr. head === :block # ||
95
+ # expr.head === :let
96
+ end
97
+
98
+ function find_sync (stmt)
99
+ result = false
100
+ postwalk (stmt) do expr
101
+ result |= is_sync (expr)
102
+ expr
103
+ end
104
+ return result
105
+ end
106
+
107
+ # TODO proper handling of LineInfo
108
+ function split (
109
+ emit,
110
+ stmts,
111
+ indicies = Any[], private = Set {Symbol} (),
112
+ )
113
+ # 1. Split the code into blocks separated by `@synchronize`
114
+ # 2. Aggregate `@index` expressions
115
+ # 3. Hoist allocations
116
+ # 4. Hoist uniforms
117
+
118
+ current = Any[]
119
+ allocations = Any[]
120
+ private_allocations = Any[]
121
+ new_stmts = Any[]
122
+ for stmt in stmts
123
+ has_sync = find_sync (stmt)
124
+ if has_sync
125
+ loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private))
126
+ push! (new_stmts, emit (loop))
127
+ allocations = Any[]
128
+ private_allocations = Any[]
129
+ current = Any[]
130
+ is_sync (stmt) && continue
131
+
132
+ # Recurse into scope constructs
133
+ # TODO : This currently implements hard scoping
134
+ # probably need to implemet soft scoping
135
+ # by not deepcopying the environment.
136
+ recurse (x) = x
137
+ function recurse (expr:: Expr )
138
+ expr = unblock (expr)
139
+ if is_scope_construct (expr) && any (find_sync, expr. args)
140
+ new_args = unblock (split (emit, expr. args, deepcopy (indicies), deepcopy (private)))
141
+ return Expr (expr. head, new_args... )
142
+ else
143
+ return Expr (expr. head, map (recurse, expr. args)... )
144
+ end
145
+ end
146
+ push! (new_stmts, recurse (stmt))
147
+ continue
148
+ end
149
+
150
+ if @capture (stmt, @uniform x_)
151
+ push! (allocations, stmt)
152
+ continue
153
+ elseif @capture (stmt, @private lhs_ = rhs_)
154
+ push! (private, lhs)
155
+ push! (private_allocations, :($ lhs = $ rhs))
156
+ continue
157
+ elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
158
+ if @capture (rhs, @index (args__))
159
+ push! (indicies, stmt)
160
+ continue
161
+ elseif @capture (rhs, @localmem (args__) | @uniform (args__))
162
+ push! (allocations, stmt)
163
+ continue
164
+ elseif @capture (rhs, @private (T_, dims_))
165
+ # Implement the legacy `mem = @private T dims` as
166
+ # mem = Scratchpad(T, Val(dims))
167
+
168
+ if dims isa Integer
169
+ dims = (dims,)
170
+ end
171
+ alloc = :($ Scratchpad (__ctx__, $ T, Val ($ dims)))
172
+ push! (allocations, :($ lhs = $ alloc))
173
+ push! (private, lhs)
174
+ continue
175
+ end
176
+ end
177
+
178
+ push! (current, stmt)
179
+ end
180
+
181
+ # everything since the last `@synchronize`
182
+ if ! isempty (current)
183
+ loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private))
184
+ push! (new_stmts, emit (loop))
185
+ end
186
+ return new_stmts
187
+ end
188
+
189
+ function emit_gpu (loop)
190
+ stmts = Any[]
191
+ append! (stmts, loop. allocations)
192
+ for stmt in loop. private_allocations
193
+ if @capture (stmt, lhs_ = rhs_)
194
+ push! (stmts, :($ lhs = $ rhs))
195
+ else
196
+ error (" @private $stmt not an assignment" )
197
+ end
198
+ end
199
+
200
+ # don't emit empty loops
201
+ if ! (isempty (loop. stmts) || all (s -> s isa LineNumberNode, loop. stmts))
202
+ body = Expr (:block , loop. stmts... )
203
+ body = postwalk (body) do expr
204
+ if @capture (expr, lhs_ = rhs_)
205
+ if lhs in loop. private
206
+ error (" Can't assign to variables marked private" )
207
+ end
208
+ end
209
+ return expr
210
+ end
211
+ loopexpr = quote
212
+ if __active_lane__
213
+ $ (loop. indicies... )
214
+ $ (unblock (body))
215
+ end
216
+ $ __synchronize ()
217
+ end
218
+ push! (stmts, loopexpr)
219
+ end
220
+
221
+ return unblock (Expr (:block , stmts... ))
222
+ end
0 commit comments