Skip to content

Commit ff693e2

Browse files
authored
Work around Julia's Base.Sort.MissingOptimization bugs (#78)
1 parent 4f1b96e commit ff693e2

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

src/SortingAlgorithms.jl

+25-4
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ struct TimSortAlg <: Algorithm end
1616
struct RadixSortAlg <: Algorithm end
1717
struct CombSortAlg <: Algorithm end
1818

19-
function maybe_optimize(x::Algorithm)
19+
function maybe_optimize(x::Algorithm)
2020
isdefined(Base.Sort, :InitialOptimizations) ? Base.Sort.InitialOptimizations(x) : x
21-
end
21+
end
2222
const HeapSort = maybe_optimize(HeapSortAlg())
2323
const TimSort = maybe_optimize(TimSortAlg())
24-
# Whenever InitialOptimizations is defined, RadixSort falls
24+
# Whenever InitialOptimizations is defined, RadixSort falls
2525
# back to Base.DEFAULT_STABLE which already includes them.
2626
const RadixSort = RadixSortAlg()
2727

@@ -79,6 +79,27 @@ end
7979
#
8080
# Original author: @kmsquire
8181

82+
@static if v"1.9.0-alpha" <= VERSION <= v"1.9.1"
83+
function Base.getindex(v::Base.Sort.WithoutMissingVector, i::UnitRange)
84+
out = Vector{eltype(v)}(undef, length(i))
85+
out .= v.data[i]
86+
out
87+
end
88+
89+
# skip MissingOptimization due to JuliaLang/julia#50171
90+
const _FIVE_ARG_SAFE_DEFAULT_STABLE = Base.DEFAULT_STABLE.next
91+
92+
# Explicitly define conversion from _sort!(v, alg, order, kw) to sort!(v, lo, hi, alg, order)
93+
# To avoid excessively strict dispatch loop detection
94+
function Base.Sort._sort!(v::AbstractVector, a::Union{HeapSortAlg, TimSortAlg, RadixSortAlg, CombSortAlg}, o::Base.Order.Ordering, kw)
95+
Base.Sort.@getkw lo hi scratch
96+
sort!(v, lo, hi, a, o)
97+
scratch
98+
end
99+
else
100+
const _FIVE_ARG_SAFE_DEFAULT_STABLE = Base.DEFAULT_STABLE
101+
end
102+
82103
const Run = UnitRange{Int}
83104

84105
const MIN_GALLOP = 7
@@ -490,7 +511,7 @@ function sort!(v::AbstractVector, lo::Int, hi::Int, ::TimSortAlg, o::Ordering)
490511
# Make a run of length minrun
491512
count = min(minrun, hi-i+1)
492513
run_range = i:i+count-1
493-
sort!(v, i, i+count-1, DEFAULT_STABLE, o)
514+
sort!(v, i, i+count-1, _FIVE_ARG_SAFE_DEFAULT_STABLE, o)
494515
else
495516
if !issorted(run_range)
496517
run_range = last(run_range):first(run_range)

test/runtests.jl

+25-4
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@ using StatsBase
44
using Random
55

66
a = rand(1:10000, 1000)
7+
am = [rand() < .9 ? i : missing for i in a]
78

8-
for alg in [TimSort, HeapSort, RadixSort, CombSort]
9+
for alg in [TimSort, HeapSort, RadixSort, CombSort, SortingAlgorithms.TimSortAlg()]
910
b = sort(a, alg=alg)
1011
@test issorted(b)
1112
ix = sortperm(a, alg=alg)
1213
b = a[ix]
1314
@test issorted(b)
1415
@test a[ix] == b
1516

17+
# legacy 3-argument calling convention
18+
@test b == sort!(copy(a), alg, Base.Order.Forward)
19+
1620
b = sort(a, alg=alg, rev=true)
1721
@test issorted(b, rev=true)
1822
ix = sortperm(a, alg=alg, rev=true)
@@ -34,9 +38,26 @@ for alg in [TimSort, HeapSort, RadixSort, CombSort]
3438
invpermute!(c, ix)
3539
@test c == a
3640

37-
if alg != RadixSort # RadixSort does not work with Lt orderings
41+
if alg != RadixSort # RadixSort does not work with Lt orderings or missing
3842
c = sort(a, alg=alg, lt=(>))
3943
@test b == c
44+
45+
# Issue https://github.com/JuliaData/DataFrames.jl/issues/3340
46+
bm1 = sort(am, alg=alg)
47+
@test issorted(bm1)
48+
@test count(ismissing, bm1) == count(ismissing, am)
49+
50+
bm2 = am[sortperm(am, alg=alg)]
51+
@test issorted(bm2)
52+
@test count(ismissing, bm2) == count(ismissing, am)
53+
54+
bm3 = am[sortperm!(collect(eachindex(am)), am, alg=alg)]
55+
@test issorted(bm3)
56+
@test count(ismissing, bm3) == count(ismissing, am)
57+
58+
if alg == TimSort # Stable
59+
@test all(bm1 .=== bm2 .=== bm3)
60+
end
4061
end
4162

4263
c = sort(a, alg=alg, by=x->1/x)
@@ -103,8 +124,8 @@ for n in [0:10..., 100, 101, 1000, 1001]
103124
# test float sorting with NaNs
104125
s = sort(v, alg=alg, order=ord)
105126
@test issorted(s, order=ord)
106-
107-
# This tests that NaNs (which compare equivalent) are treated stably
127+
128+
# This tests that NaNs (which compare equivalent) are treated stably
108129
# even when the underlying algorithm is unstable. That it happens to
109130
# pass is not a part of the public API:
110131
@test reinterpret(UInt64, v[map(isnan, v)]) == reinterpret(UInt64, s[map(isnan, s)])

0 commit comments

Comments
 (0)