Skip to content

Commit 482af35

Browse files
committed
implement bitonic sorting network for SVectors
1 parent 95f2578 commit 482af35

File tree

4 files changed

+93
-0
lines changed

4 files changed

+93
-0
lines changed

src/StaticArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ include("abstractarray.jl")
123123
include("indexing.jl")
124124
include("broadcast.jl")
125125
include("mapreduce.jl")
126+
include("sort.jl")
126127
include("arraymath.jl")
127128
include("linalg.jl")
128129
include("matrix_multiply.jl")

src/sort.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import Base.@_inline_meta
2+
import Base.Order: Ordering, Forward, ReverseOrdering, ord
3+
import Base.Sort: Algorithm, defalg, lt, sort
4+
5+
6+
struct BitonicSortAlg <: Algorithm end
7+
8+
const BitonicSort = BitonicSortAlg()
9+
10+
# BitonicSort has non-optimal asymptotic behaviour, so we define a cutoff length.
11+
# This also prevents compilation time to skyrocket for larger vectors.
12+
defalg(a::StaticVector) = isimmutable(a) && length(a) <= 20 ? BitonicSort : QuickSort
13+
14+
@inline function sort(a::StaticVector;
15+
alg::Algorithm = defalg(a),
16+
lt = isless,
17+
by = identity,
18+
rev::Union{Bool,Nothing} = nothing,
19+
order::Ordering = Forward)
20+
length(a) <= 1 && return a
21+
ordr = ord(lt, by, rev, order)
22+
return _sort(Size(a), alg, ordr, a)
23+
end
24+
25+
@inline _sort(_, alg, order, a::StaticVector) = sort!(Base.copymutable(a); alg=alg, order=order)
26+
@inline _sort(_, alg::BitonicSortAlg, order, a::StaticVector) = similar_type(a)(_sort(Tuple(a), alg, order))
27+
28+
# Implementation loosely following
29+
# https://www.inf.hs-flensburg.de/lang/algorithmen/sortieren/bitonic/oddn.htm
30+
@generated function _sort(a::NTuple{N}, ::BitonicSortAlg, order) where N
31+
function swap_expr(i, j, rev)
32+
ai = Symbol('a', i)
33+
aj = Symbol('a', j)
34+
order = rev ? :revorder : :order
35+
return :( ($ai, $aj) = lt($order, $ai, $aj) ? ($ai, $aj) : ($aj, $ai) )
36+
end
37+
38+
function merge_exprs(idx, rev)
39+
exprs = Expr[]
40+
length(idx) == 1 && return exprs
41+
42+
ci = 2^(ceil(Int, log2(length(idx))) - 1)
43+
# TODO: generate simd code for these swaps
44+
for i in first(idx):last(idx)-ci
45+
push!(exprs, swap_expr(i, i+ci, rev))
46+
end
47+
append!(exprs, merge_exprs(idx[1:ci], rev))
48+
append!(exprs, merge_exprs(idx[ci+1:end], rev))
49+
return exprs
50+
end
51+
52+
function sort_exprs(idx, rev=false)
53+
exprs = Expr[]
54+
length(idx) == 1 && return exprs
55+
56+
append!(exprs, sort_exprs(idx[1:end÷2], !rev))
57+
append!(exprs, sort_exprs(idx[end÷2+1:end], rev))
58+
append!(exprs, merge_exprs(idx, rev))
59+
return exprs
60+
end
61+
62+
idx = 1:N
63+
symlist = (Symbol('a', i) for i in idx)
64+
return quote
65+
@_inline_meta
66+
revorder = Base.Order.ReverseOrdering(order)
67+
($(symlist...),) = a
68+
($(sort_exprs(idx)...);)
69+
return ($(symlist...),)
70+
end
71+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ include("abstractarray.jl")
3333
include("indexing.jl")
3434
include("initializers.jl")
3535
Random.seed!(42); include("mapreduce.jl")
36+
Random.seed!(42); include("sort.jl")
3637
Random.seed!(42); include("accumulate.jl")
3738
Random.seed!(42); include("arraymath.jl")
3839
include("broadcast.jl")

test/sort.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using StaticArrays, Test
2+
3+
@testset "sort" begin
4+
5+
@testset "basics" for T in (Int, Float64)
6+
for N in (0, 1, 2, 3, 10, 20)
7+
v = rand(SVector{N,T})
8+
vs = sort!(Base.copymutable(v))
9+
10+
@test vs == @inferred sort(v)
11+
@test 0 == @allocated sort(v)
12+
end
13+
end
14+
15+
@testset "fallbacks" begin
16+
@test @inferred(sort(rand(SVector{3}), alg=QuickSort)) isa MVector
17+
@test @inferred(sort(rand(SVector{21}))) isa MVector
18+
end
19+
20+
end

0 commit comments

Comments
 (0)