Skip to content

Commit a48a158

Browse files
committed
Forbid divergent execution of work-group barriers
1 parent a6ae55b commit a48a158

File tree

2 files changed

+162
-11
lines changed

2 files changed

+162
-11
lines changed

src/KernelAbstractions.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,9 @@ end
297297
After a `@synchronize` statement all reads and writes to global and local memory
298298
from each thread in the workgroup are visible to all other threads in the
299299
workgroup.
300+
301+
!!! note
302+
`@synchronize()` must be encountered by all workitems of a work-group executing the kernel or by none at all.
300303
"""
301304
macro synchronize()
302305
return quote
@@ -314,10 +317,15 @@ workgroup. `cond` is not allowed to have any visible side effects.
314317
# Platform differences
315318
- `GPU`: This synchronization will only occur if `cond` evaluates to `true`.
316319
- `CPU`: This synchronization will always occur.
320+
321+
!!! warning
322+
This variant of the `@synchronize` macro violates the requirement that `@synchronize` must be encountered
323+
by all workitems of a work-group executing the kernel or by none at all.
324+
Since `v0.9.34`, this variant of the macro is deprecated and lowers to `@synchronize()`.
317325
"""
318326
macro synchronize(cond)
319327
return quote
320-
# Removed by this commit: the barrier used to be guarded by the caller's `cond`.
$(esc(cond)) && $__synchronize()
328+
# Added by this commit: `cond` is no longer spliced into the expansion, so the
# barrier is emitted unconditionally.
$__synchronize()
321329
end
322330
end
323331

src/macros.jl

+153-10
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,165 @@ function transform_gpu!(def, constargs, force_inbounds)
5858
end
5959
end
6060
pushfirst!(def[:args], :__ctx__)
61-
body = def[:body]
61+
new_stmts = Expr[]
62+
body = MacroTools.flatten(def[:body])
63+
stmts = body.args
64+
push!(new_stmts, Expr(:aliasscope))
65+
push!(new_stmts, :(__active_lane__ = $__validindex(__ctx__)))
6266
if force_inbounds
63-
body = quote
64-
@inbounds $(body)
65-
end
67+
push!(new_stmts, Expr(:inbounds, true))
6668
end
67-
body = quote
68-
if $__validindex(__ctx__)
69-
$(body)
70-
end
71-
return nothing
69+
append!(new_stmts, split(emit_gpu, body.args))
70+
if force_inbounds
71+
push!(new_stmts, Expr(:inbounds, :pop))
7272
end
73+
push!(new_stmts, Expr(:popaliasscope))
74+
push!(new_stmts, :(return nothing))
7375
def[:body] = Expr(
7476
:let,
7577
Expr(:block, let_constargs...),
76-
body,
78+
Expr(:block, new_stmts...),
7779
)
7880
return
7981
end
82+
83+
# Bundle of everything `split` collects for one stretch of kernel statements.
# `split` constructs it and passes it to the `emit` callback (e.g. `emit_gpu`).
struct WorkgroupLoop
84+
# statements whose right-hand side is an `@index(...)` call
indicies::Vector{Any}
85+
# ordinary statements gathered since the previous sync-containing statement
stmts::Vector{Any}
86+
# `@uniform` / `@localmem` statements and the `Scratchpad` allocations that
# `split` generates for the legacy `mem = @private T dims` form
allocations::Vector{Any}
87+
# `lhs = rhs` assignments extracted from `@private lhs = rhs`
private_allocations::Vector{Any}
88+
# names that `split` recorded as private; `emit_gpu` rejects assignments to them
private::Set{Symbol}
89+
end
90+
91+
# Report whether `expr` matches `@synchronize()` or `@synchronize(a)`; the value is
# the result of the `@capture` pattern match (presumably MacroTools' `@capture`).
is_sync(expr) = @capture(expr, @synchronize() | @synchronize(a_))
92+
93+
# Decide whether `expr` is a scope construct that `split` recurses into.
# Only `:block` expressions qualify at present; the `:let` alternative below is
# left commented out.
function is_scope_construct(expr::Expr)
94+
return expr.head === :block # ||
95+
# expr.head === :let
96+
end
end
97+
98+
# Return `true` if `stmt` itself, or anything `postwalk` visits inside it, satisfies
# `is_sync`; otherwise return `false`.
function find_sync(stmt)
99+
result = false
100+
# The walk is used purely for inspection: every visited node is returned unchanged.
postwalk(stmt) do expr
101+
# Accumulate with `|=`; the traversal does not stop early once a match is found.
result |= is_sync(expr)
102+
expr
103+
end
104+
return result
105+
end
106+
107+
# TODO proper handling of LineInfo
108+
# Partition the kernel statements `stmts` at every statement that contains a
# `@synchronize`, and return the vector of replacement statements.
#
# Arguments:
# - `emit`: callback that receives a `WorkgroupLoop` and returns the code for it
#   (`transform_gpu!` passes `emit_gpu`).
# - `stmts`: the statements to process.
# - `indicies`: accumulated `@index` assignments; mutated with `push!`.
# - `private`: accumulated names of private variables; mutated with `push!`.
#
# Recursive calls receive deep copies of `indicies` and `private`, so additions made
# inside a nested block do not leak back to the enclosing level.
function split(
109+
emit,
110+
stmts,
111+
indicies = Any[], private = Set{Symbol}(),
112+
)
113+
# 1. Split the code into blocks separated by `@synchronize`
114+
# 2. Aggregate `@index` expressions
115+
# 3. Hoist allocations
116+
# 4. Hoist uniforms
117+
118+
# Per-segment accumulators; they are reset each time a sync-containing statement
# closes the current segment.
current = Any[]
119+
allocations = Any[]
120+
private_allocations = Any[]
121+
# The output: one entry per emitted segment or rewritten nested statement.
new_stmts = Any[]
122+
for stmt in stmts
123+
has_sync = find_sync(stmt)
124+
if has_sync
125+
# Close the current segment. `indicies` and `private` are snapshotted so that
# later `push!`es do not alter the loop that was already handed to `emit`.
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private))
126+
push!(new_stmts, emit(loop))
127+
allocations = Any[]
128+
private_allocations = Any[]
129+
current = Any[]
130+
# A bare `@synchronize` only acts as a separator and is dropped here.
is_sync(stmt) && continue
131+
132+
# Recurse into scope constructs
133+
# TODO: This currently implements hard scoping
134+
# probably need to implement soft scoping
135+
# by not deepcopying the environment.
# Non-`Expr` leaves are returned untouched.
recurse(x) = x
137+
function recurse(expr::Expr)
138+
expr = unblock(expr)
139+
if is_scope_construct(expr) && any(find_sync, expr.args)
140+
# A block with a sync somewhere inside: split its statements
# using copies of the current environment.
new_args = unblock(split(emit, expr.args, deepcopy(indicies), deepcopy(private)))
141+
return Expr(expr.head, new_args...)
142+
else
143+
# Any other expression: rebuild it with each argument processed recursively.
return Expr(expr.head, map(recurse, expr.args)...)
144+
end
145+
end
146+
push!(new_stmts, recurse(stmt))
147+
continue
148+
end
149+
150+
# Statements without a sync: classify them into one of the accumulators.
if @capture(stmt, @uniform x_)
151+
push!(allocations, stmt)
152+
continue
153+
elseif @capture(stmt, @private lhs_ = rhs_)
154+
# Record the name as private and keep only the plain assignment.
push!(private, lhs)
155+
push!(private_allocations, :($lhs = $rhs))
156+
continue
157+
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
158+
if @capture(rhs, @index(args__))
159+
push!(indicies, stmt)
160+
continue
161+
elseif @capture(rhs, @localmem(args__) | @uniform(args__))
162+
push!(allocations, stmt)
163+
continue
164+
elseif @capture(rhs, @private(T_, dims_))
165+
# Implement the legacy `mem = @private T dims` as
166+
# mem = Scratchpad(T, Val(dims))
167+
168+
# A literal integer `dims` is wrapped into a 1-tuple before being spliced into `Val`.
if dims isa Integer
169+
dims = (dims,)
170+
end
171+
alloc = :($Scratchpad(__ctx__, $T, Val($dims)))
172+
push!(allocations, :($lhs = $alloc))
173+
push!(private, lhs)
174+
continue
175+
end
176+
end
177+
178+
# Everything else belongs to the body of the current segment.
push!(current, stmt)
179+
end
180+
181+
# everything since the last `@synchronize`
# NOTE(review): allocations gathered after the last sync are only emitted when
# `current` is non-empty — confirm that is intended.
182+
if !isempty(current)
183+
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private))
184+
push!(new_stmts, emit(loop))
185+
end
186+
return new_stmts
187+
end
188+
189+
# Turn one `WorkgroupLoop` into the expression `split` places in the kernel body:
# the hoisted allocations first, then (if there is a body) the guarded body followed
# by a barrier. Returns the resulting block passed through `unblock`.
#
# Raises an error if a private allocation is not an assignment, or if the body
# assigns to a name listed in `loop.private`.
function emit_gpu(loop)
190+
stmts = Any[]
191+
# Allocations are emitted outside the `__active_lane__` guard.
append!(stmts, loop.allocations)
192+
for stmt in loop.private_allocations
193+
if @capture(stmt, lhs_ = rhs_)
194+
push!(stmts, :($lhs = $rhs))
195+
else
196+
error("@private $stmt not an assignment")
197+
end
198+
end
199+
200+
# don't emit empty loops
# (a segment made up only of `LineNumberNode`s counts as empty; in that case no
# barrier is emitted either)
201+
if !(isempty(loop.stmts) || all(s -> s isa LineNumberNode, loop.stmts))
202+
body = Expr(:block, loop.stmts...)
203+
# Validation pass only: every node is returned unchanged, but an assignment whose
# left-hand side is a private name aborts with an error.
body = postwalk(body) do expr
204+
if @capture(expr, lhs_ = rhs_)
205+
if lhs in loop.private
206+
error("Can't assign to variables marked private")
207+
end
208+
end
209+
return expr
210+
end
211+
loopexpr = quote
212+
# Only the index computations and the body are guarded by `__active_lane__`.
if __active_lane__
213+
$(loop.indicies...)
214+
$(unblock(body))
215+
end
216+
# The barrier sits outside the guard, so it is reached whether or not
# `__active_lane__` is true.
$__synchronize()
217+
end
218+
push!(stmts, loopexpr)
219+
end
220+
221+
return unblock(Expr(:block, stmts...))
222+
end

0 commit comments

Comments
 (0)