Skip to content

Commit ff27540

Browse files
committed
implement bitonic sorting network for SVectors
1 parent 95f2578 commit ff27540

File tree

4 files changed

+99
-0
lines changed

4 files changed

+99
-0
lines changed

src/StaticArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ include("abstractarray.jl")
123123
include("indexing.jl")
124124
include("broadcast.jl")
125125
include("mapreduce.jl")
126+
include("sort.jl")
126127
include("arraymath.jl")
127128
include("linalg.jl")
128129
include("matrix_multiply.jl")

src/sort.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import Base.Order: Ordering, Forward, ReverseOrdering, ord
2+
import Base.Sort: Algorithm, lt, sort
3+
4+
5+
# Algorithm tag that selects the bitonic sorting network implementation.
# It carries no data; it exists purely for dispatch.
struct BitonicSortAlg <: Algorithm
end

# Ready-made instance of the tag, to be passed as the `alg` keyword of `sort`.
const BitonicSort = BitonicSortAlg()
8+
9+
10+
# BitonicSort has non-optimal asymptotic behaviour, so we define a cutoff
11+
# length. This also prevents compilation time from skyrocketing for larger vectors.
12+
# Pick the default sorting algorithm for the static vector `a`:
# `BitonicSort` when `a` is immutable and has at most 20 elements,
# `QuickSort` in every other case.
function defalg(a::StaticVector)
    if isimmutable(a) && length(a) <= 20
        return BitonicSort
    end
    return QuickSort
end
14+
15+
# Return a sorted copy of the static vector `a`, leaving `a` itself untouched.
#
# Keywords:
# - `alg`:   sorting algorithm, defaults to `defalg(a)`.
# - `lt`, `by`, `rev`, `order`: combined into a single ordering through `ord`,
#   which is then handed to `_sort` together with `alg`.
@inline function sort(a::StaticVector;
                      alg::Algorithm = defalg(a),
                      lt = isless,
                      by = identity,
                      rev::Union{Bool,Nothing} = nothing,
                      order::Ordering = Forward)
    # Zero or one element: nothing to reorder. Still build a fresh array the same
    # way the `_sort` methods do, because returning `a` itself would alias a
    # mutable input and let writes to the result modify the original.
    length(a) <= 1 && return similar_type(a)(Tuple(a))
    ordr = ord(lt, by, rev, order)
    return _sort(a, alg, ordr)
end
25+
26+
# Fallback for arbitrary algorithms: sort a mutable copy of `a` in place with
# `sort!` and convert the result back to the array type given by `similar_type(a)`.
@inline function _sort(a::StaticVector, alg, order)
    scratch = Base.copymutable(a)
    sorted = sort!(scratch; alg=alg, order=order)
    return similar_type(a)(sorted)
end
28+
29+
# BitonicSort on a static vector: sort the elements as a tuple, then rebuild an
# array of the type given by `similar_type(a)` from the sorted tuple.
@inline function _sort(a::StaticVector, alg::BitonicSortAlg, order)
    elems = Tuple(a)
    sorted = _sort(elems, alg, order)
    return similar_type(a)(sorted)
end
31+
32+
# Implementation loosely following
33+
# https://www.inf.hs-flensburg.de/lang/algorithmen/sortieren/bitonic/oddn.htm
34+
# Sort the tuple `a` according to `order` with a bitonic sorting network.
#
# The network is generated for the tuple length `N`: the elements are unpacked
# into local variables `a1, ..., aN`, a fixed sequence of compare-exchange
# assignments is emitted, and the variables are packed back into a tuple.
@generated function _sort(a::NTuple{N}, ::BitonicSortAlg, order) where N
    # Tuples with fewer than two elements are already sorted. Returning early
    # also guarantees the recursive helpers below never start from an empty
    # index range.
    N <= 1 && return :(a)

    # Compare-exchange of the variables `a$i` and `a$j`. With `rev == true` the
    # comparison is done with `revorder` instead of `order`.
    function swap_expr(i, j, rev)
        ai = Symbol('a', i)
        aj = Symbol('a', j)
        # Deliberately not called `order`: assigning to that name here would
        # rebind the generator's own `order` argument captured by this closure.
        ordsym = rev ? :revorder : :order
        return :( ($ai, $aj) = lt($ordsym, $ai, $aj) ? ($ai, $aj) : ($aj, $ai) )
    end

    # Compare-exchange steps that merge the index range `idx`.
    function merge_exprs(idx, rev)
        exprs = Expr[]
        # `<= 1` rather than `== 1` so an empty range terminates the recursion too.
        length(idx) <= 1 && return exprs

        # Greatest power of two strictly smaller than length(idx).
        ci = prevpow(2, length(idx) - 1)
        # TODO: generate simd code for these swaps
        for i in first(idx):last(idx)-ci
            push!(exprs, swap_expr(i, i+ci, rev))
        end
        append!(exprs, merge_exprs(idx[1:ci], rev))
        append!(exprs, merge_exprs(idx[ci+1:end], rev))
        return exprs
    end

    # Sort the two halves of `idx` in opposite directions, then merge them.
    function sort_exprs(idx, rev=false)
        exprs = Expr[]
        # `<= 1` rather than `== 1` so an empty range terminates the recursion too.
        length(idx) <= 1 && return exprs

        append!(exprs, sort_exprs(idx[1:end÷2], !rev))
        append!(exprs, sort_exprs(idx[end÷2+1:end], rev))
        append!(exprs, merge_exprs(idx, rev))
        return exprs
    end

    idx = 1:N
    symlist = [Symbol('a', i) for i in idx]
    return quote
        @_inline_meta
        revorder = Base.Order.ReverseOrdering(order)
        ($(symlist...),) = a
        ($(sort_exprs(idx)...);)
        return ($(symlist...),)
    end
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ include("abstractarray.jl")
3333
include("indexing.jl")
3434
include("initializers.jl")
3535
Random.seed!(42); include("mapreduce.jl")
36+
Random.seed!(42); include("sort.jl")
3637
Random.seed!(42); include("accumulate.jl")
3738
Random.seed!(42); include("arraymath.jl")
3839
include("broadcast.jl")

test/sort.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using StaticArrays, Test
2+
3+
# Checks `sort` on static vectors against `sort` on the equivalent `Vector`.
@testset "sort" begin

    @testset "basics" for T in (Int, Float64)
        for N in (0, 1, 2, 3, 10, 20, 30)
            # Random immutable input, a mutable copy of it, and the reference result.
            svec = rand(SVector{N,T})
            mvec = MVector{N,T}(svec)
            expected = sort(Vector(svec))

            # The result type must be inferred and must match the input's family.
            @test @inferred(sort(svec)) isa SVector
            @test @inferred(sort(svec, alg=QuickSort)) isa SVector
            @test @inferred(sort(mvec)) isa MVector

            # The sorted values must agree with the reference.
            @test expected == sort(svec)
            @test expected == sort(mvec)

            # Allocation check is disabled: @allocated seems broken since 1.4
            #N <= 20 && @test 0 == @allocated sort(vs)
        end
    end

end

0 commit comments

Comments
 (0)