Skip to content

Commit 2ca8a3e

Browse files
committed
implement bitonic sorting network for SVectors
1 parent 95f2578 commit 2ca8a3e

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

src/StaticArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ include("abstractarray.jl")
123123
include("indexing.jl")
124124
include("broadcast.jl")
125125
include("mapreduce.jl")
126+
include("sort.jl")
126127
include("arraymath.jl")
127128
include("linalg.jl")
128129
include("matrix_multiply.jl")

src/sort.jl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import Base.@_inline_meta
2+
import Base.Order: Ordering, Forward, ReverseOrdering, ord
3+
import Base.Sort: Algorithm, defalg, lt, sort
4+
5+
6+
struct BitonicSortAlg <: Algorithm end
7+
struct MinSizeSortAlg <: Algorithm end
8+
struct MinDepthSortAlg <: Algorithm end
9+
const MinSortAlg = Union{MinSizeSortAlg,MinDepthSortAlg}
10+
11+
const BitonicSort = BitonicSortAlg()
12+
const MinSizeSort = MinSizeSortAlg()
13+
const MinDepthSort = MinDepthSortAlg()
14+
15+
defalg(::SVector) = BitonicSort
16+
17+
function sort(a::SVector;
18+
alg::Algorithm = defalg(a),
19+
lt = isless,
20+
by = identity,
21+
rev::Union{Bool,Nothing} = nothing,
22+
order::Ordering = Forward)
23+
ordr = ord(lt, by, rev, order)
24+
_sort(Size(a), alg, ordr, a)
25+
end
26+
27+
_sort(::Size{T}, alg, _, _) where T =
28+
error("sorting algorithm $alg unimplemented for static array of size $T")
29+
30+
31+
@inline _cmpswap(order, a, b) = lt(order, a, b) ? (a, b) : (b, a)
32+
33+
@inline _sort(::Size{(1,)}, _, _, a) = a
34+
@inline _sort(::Size{(2,)}, _, order, (a1, a2)) = SVector(_cmpswap(order, a1, a2))
35+
36+
@inline _sort(::Size{(1,)}, ::BitonicSortAlg, _, a) = a
37+
@inline _sort(s::Size{(2,)}, ::BitonicSortAlg, order, (a1, a2)) = SVector(_cmpswap(order, a1, a2))
38+
@generated function _sort(::Size{S}, ::BitonicSortAlg, order, a) where {S}
39+
function swap_expr(i, j, dir)
40+
ai = Symbol('a', i)
41+
aj = Symbol('a', j)
42+
order = dir ? :revorder : :order
43+
return :( ($ai, $aj) = _cmpswap($order, $ai, $aj) )
44+
end
45+
46+
function merge_exprs(idx, dir)
47+
exprs = Expr[]
48+
length(idx) == 1 && return exprs
49+
50+
ci = 2^(ceil(Int, log2(length(idx))) - 1)
51+
# TODO: generate simd code for these swaps
52+
for i in first(idx):last(idx)-ci
53+
push!(exprs, swap_expr(i, i+ci, dir))
54+
end
55+
append!(exprs, merge_exprs(idx[1:ci], dir))
56+
append!(exprs, merge_exprs(idx[ci+1:end], dir))
57+
return exprs
58+
end
59+
60+
function sort_exprs(idx, dir)
61+
exprs = Expr[]
62+
length(idx) == 1 && return exprs
63+
64+
append!(exprs, sort_exprs(idx[1:end÷2], !dir))
65+
append!(exprs, sort_exprs(idx[end÷2+1:end], dir))
66+
append!(exprs, merge_exprs(idx, dir))
67+
return exprs
68+
end
69+
70+
idx = 1:prod(S)
71+
symlist = (Symbol('a', i) for i in idx)
72+
sym_exprs = (:( $ai = a[$i] ) for (i, ai) in enumerate(symlist))
73+
return quote
74+
@_inline_meta
75+
revorder = Base.Order.ReverseOrdering(order)
76+
@inbounds ($(sym_exprs...);)
77+
($(sort_exprs(idx, false)...);)
78+
return SVector(($(symlist...)))
79+
end
80+
end
81+
82+
83+
## TODO: manually implementing minimal sorting networks for small lengths might
84+
## be worthwhile
85+
#
86+
#macro _cmpswap(order, a, b)
87+
# return esc(:( ($a, $b) = _cmpswap(order, $a, $b) ))
88+
#end
89+
#
90+
#@inline function _sort(::Size{(3,)}, ::MinSortAlg, order, (a1, a2, a3))
91+
# @_cmpswap order a1 a3
92+
# @_cmpswap order a1 a2
93+
# @_cmpswap order a2 a3
94+
# return SVector(a1, a2, a3)
95+
#end
96+
#
97+
#@inline function _sort(::Size{(4,)}, ::MinSortAlg, order, (a1, a2, a3, a4))
98+
# @_cmpswap order a1 a3
99+
# @_cmpswap order a2 a4
100+
# @_cmpswap order a1 a2
101+
# @_cmpswap order a3 a4
102+
# @_cmpswap order a2 a3
103+
# return SVector(a1, a2, a3, a4)
104+
#end

0 commit comments

Comments
 (0)