@@ -37,13 +37,6 @@ using ..CUDA: i32
3737 (eq && a′ == b′) || lt (a′, b′)
3838end
3939
40- # To allow sorting tuples of numbers:
41- @inline _zero (x) = Base. zero (x)
42- @inline _zero (:: Type{T} ) where {T<: Tuple{Vararg{Any,N}} } where {N} = ntuple (i -> zero (T. parameters[i]), N)
43-
44- @inline _one (x) = Base. one (x)
45- @inline _one (:: Type{T} ) where {T<: Tuple{Vararg{Any,N}} } where {N} = ntuple (i -> one (T. parameters[i]), N)
46-
4740
4841# Batch partitioning
4942"""
@@ -80,7 +73,12 @@ Uses block y index to decide which values to operate on.
8073 sync_threads ()
8174 blockIdx_yz = (blockIdx (). z - 1 i32) * gridDim (). y + blockIdx (). y
8275 idx0 = lo + (blockIdx_yz - 1 i32) * blockDim (). x + threadIdx (). x
83- val = idx0 <= hi ? values[idx0] : _one (eltype (values))
76+ val = if idx0 <= hi
77+ values[idx0]
78+ else
79+ Ref {eltype(values)} ()[] # undef
80+ # if idx0 > hi, val, comparison and dest_idx are unused
81+ end
8482 comparison = flex_lt (pivot, val, parity, lt, by)
8583
8684 @inbounds if idx0 <= hi
@@ -190,7 +188,7 @@ Must only run on 1 SM.
190188 swap = if threadIdx (). x <= to_move
191189 vals[lo + a + threadIdx (). x]
192190 else
193- _zero ( eltype (vals)) # unused value
191+ Ref { eltype(vals)} ()[] # undef
194192 end
195193 sync_threads ()
196194 if threadIdx (). x <= to_move
@@ -222,7 +220,6 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
222220
223221 @inbounds swap[threadIdx (). x] = vals[lo + threadIdx (). x * stride]
224222 sync_threads ()
225- old_val = _zero (eltype (swap))
226223
227224 log_blockDim = begin
228225 out = 0
@@ -245,8 +242,10 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
245242 to_swap = (i & k) == 0 && bitonic_lt (l, i) || (i & k) != 0 && bitonic_lt (i, l)
246243 to_swap = to_swap == (i < l)
247244
248- if to_swap
249- @inbounds old_val = swap[l + 1 ]
245+ old_val = if to_swap
246+ @inbounds swap[l + 1 ]
247+ else
248+ Ref {eltype(swap)} ()[] # undef
250249 end
251250 sync_threads ()
252251 if to_swap
@@ -279,7 +278,7 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
279278 buddy_val = if 1 <= buddy <= L && threadIdx (). x <= L
280279 swap[buddy]
281280 else
282- _zero ( eltype (swap)) # unused value
281+ Ref { eltype(swap)} ()[] # undef
283282 end
284283 sync_threads ()
285284 if 1 <= buddy <= L && threadIdx (). x <= L
0 commit comments