@@ -37,6 +37,13 @@ using ..CUDA: i32
3737 (eq && a′ == b′) || lt (a′, b′)
3838end
3939
40+ # To allow sorting tuples of numbers:
41+ @inline _zero (x) = Base. zero (x)
42+ @inline _zero (:: Type{T} ) where {T<: Tuple{Vararg{Any,N}} } where {N} = ntuple (i -> zero (T. parameters[i]), N)
43+
44+ @inline _one (x) = Base. one (x)
45+ @inline _one (:: Type{T} ) where {T<: Tuple{Vararg{Any,N}} } where {N} = ntuple (i -> one (T. parameters[i]), N)
46+
4047
4148# Batch partitioning
4249"""
@@ -73,7 +80,7 @@ Uses block y index to decide which values to operate on.
7380 sync_threads ()
7481 blockIdx_yz = (blockIdx (). z - 1 i32) * gridDim (). y + blockIdx (). y
7582 idx0 = lo + (blockIdx_yz - 1 i32) * blockDim (). x + threadIdx (). x
76- val = idx0 <= hi ? values[idx0] : one (eltype (values))
83+ val = idx0 <= hi ? values[idx0] : _one (eltype (values))
7784 comparison = flex_lt (pivot, val, parity, lt, by)
7885
7986 @inbounds if idx0 <= hi
@@ -183,7 +190,7 @@ Must only run on 1 SM.
183190 swap = if threadIdx (). x <= to_move
184191 vals[lo + a + threadIdx (). x]
185192 else
186- zero (eltype (vals)) # unused value
193+ _zero (eltype (vals)) # unused value
187194 end
188195 sync_threads ()
189196 if threadIdx (). x <= to_move
@@ -215,7 +222,7 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
215222
216223 @inbounds swap[threadIdx (). x] = vals[lo + threadIdx (). x * stride]
217224 sync_threads ()
218- old_val = zero (eltype (swap))
225+ old_val = _zero (eltype (swap))
219226
220227 log_blockDim = begin
221228 out = 0
@@ -272,7 +279,7 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
272279 buddy_val = if 1 <= buddy <= L && threadIdx (). x <= L
273280 swap[buddy]
274281 else
275- zero (eltype (swap)) # unused value
282+ _zero (eltype (swap)) # unused value
276283 end
277284 sync_threads ()
278285 if 1 <= buddy <= L && threadIdx (). x <= L
0 commit comments