Skip to content

Commit 3e4b386

Browse files
committed
add init argument to count
1 parent 1e6d771 commit 3e4b386

File tree

6 files changed

+33
-20
lines changed

6 files changed

+33
-20
lines changed

base/bitarray.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,15 +1386,15 @@ circshift!(B::BitVector, i::Integer) = circshift!(B, B, i)
13861386

13871387
## count & find ##
13881388

1389-
function bitcount(Bc::Vector{UInt64})
1390-
n = 0
1389+
function bitcount(Bc::Vector{UInt64}; init::T=0) where {T}
1390+
n::T = init
13911391
@inbounds for i = 1:length(Bc)
1392-
n += count_ones(Bc[i])
1392+
n += count_ones(Bc[i]) % T
13931393
end
13941394
return n
13951395
end
13961396

1397-
count(B::BitArray) = bitcount(B.chunks)
1397+
count(B::BitArray; init=0) = bitcount(B.chunks; init)
13981398

13991399
function unsafe_bitfindnext(Bc::Vector{UInt64}, start::Int)
14001400
chunk_start = _div64(start-1)+1

base/reduce.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -938,11 +938,10 @@ end
938938
_bool(f) = x->f(x)::Bool
939939

940940
"""
941-
count(p, itr) -> Integer
942-
count(itr) -> Integer
941+
count([f=identity,] itr; init=0) -> Integer
943942
944-
Count the number of elements in `itr` for which predicate `p` returns `true`.
945-
If `p` is omitted, counts the number of `true` elements in `itr` (which
943+
Count the number of elements in `itr` for which the function `f` returns `true`.
944+
If `f` is omitted, counts the number of `true` elements in `itr` (which
946945
should be a collection of boolean values).
947946
948947
# Examples
@@ -954,30 +953,30 @@ julia> count([true, false, true, true])
954953
3
955954
```
956955
"""
957-
count(itr) = count(identity, itr)
956+
count(itr; init=0) = count(identity, itr; init)
958957

959-
count(f, itr) = _simple_count(f, itr)
958+
count(f, itr; init=0) = _simple_count(f, itr, init)
960959

961-
function _simple_count(pred, itr)
962-
n = 0
960+
function _simple_count(pred, itr, init)
961+
n = init
963962
for x in itr
964963
n += pred(x)::Bool
965964
end
966965
return n
967966
end
968967

969-
function count(::typeof(identity), x::Array{Bool})
970-
n = 0
968+
function _simple_count(::typeof(identity), x::Array{Bool}, init::T=0) where {T}
969+
n::T = init
971970
chunks = length(x) ÷ sizeof(UInt)
972971
mask = 0x0101010101010101 % UInt
973972
GC.@preserve x begin
974973
ptr = Ptr{UInt}(pointer(x))
975974
for i in 1:chunks
976-
n += count_ones(unsafe_load(ptr, i) & mask)
975+
n += count_ones(unsafe_load(ptr, i) & mask) % T
977976
end
978977
end
979978
for i in sizeof(UInt)*chunks+1:length(x)
980-
n += x[i]
979+
n += x[i] % T
981980
end
982981
return n
983982
end

base/reducedim.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,11 +386,11 @@ julia> count(<=(2), A, dims=2)
386386
0
387387
```
388388
"""
389-
count(A::AbstractArrayOrBroadcasted; dims=:) = count(identity, A, dims=dims)
390-
count(f, A::AbstractArrayOrBroadcasted; dims=:) = _count(f, A, dims)
389+
count(A::AbstractArrayOrBroadcasted; dims=:, init=0) = count(identity, A; dims, init)
390+
count(f, A::AbstractArrayOrBroadcasted; dims=:, init=0) = _count(f, A, dims, init)
391391

392-
_count(f, A::AbstractArrayOrBroadcasted, dims::Colon) = _simple_count(f, A)
393-
_count(f, A::AbstractArrayOrBroadcasted, dims) = mapreduce(_bool(f), add_sum, A, dims=dims, init=0)
392+
_count(f, A::AbstractArrayOrBroadcasted, dims::Colon, init) = _simple_count(f, A, init)
393+
_count(f, A::AbstractArrayOrBroadcasted, dims, init) = mapreduce(_bool(f), add_sum, A; dims, init)
394394

395395
"""
396396
count!([f=identity,] r, A)

test/bitarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,8 @@ timesofar("datamove")
12191219
@check_bit_operation findall(falses(t)) ret_type
12201220
@check_bit_operation findall(bitrand(t)) ret_type
12211221
end
1222+
1223+
@test count(trues(2, 2), init=0x00) isa UInt8
12221224
end
12231225

12241226
timesofar("find")

test/reduce.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,12 @@ struct NonFunctionIsZero end
520520
@test count(NonFunctionIsZero(), [0]) == 1
521521
@test count(NonFunctionIsZero(), [1]) == 0
522522

523+
@test count(Iterators.repeated(true, 3), init=0x00) isa UInt8
524+
@test count(!=(2), Iterators.take(1:7, 3), init=Int32(0)) isa Int32
525+
@test count(identity, [true, false], init=Int8(0)) isa Int8
526+
@test count(!, [true false; false true], dims=:, init=Int16(0)) isa Int16
527+
@test count(identity, [true false; false true], dims=2, init=UInt(0)) isa Matrix{UInt}
528+
523529
## cumsum, cummin, cummax
524530

525531
z = rand(10^6)

test/reducedim.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
7777
@test @inferred(maximum(abs, Areduc, dims=region)) safe_maxabs(Areduc, region)
7878
@test @inferred(minimum(abs, Areduc, dims=region)) safe_minabs(Areduc, region)
7979
@test @inferred(count(!, Breduc, dims=region)) safe_count(.!Breduc, region)
80+
81+
@test eltype(@inferred(count(Breduc, dims=region, init=0x00))) == UInt8
82+
@test eltype(@inferred(count(!, Breduc, dims=region, init=Int16(0)))) == Int16
8083
end
8184

8285
# Combining dims and init
@@ -446,3 +449,6 @@ end
446449
@test_throws TypeError count([1], dims=1)
447450
@test_throws TypeError count!([1], [1])
448451
end
452+
453+
@test @inferred(count(false:true, dims=:, init=0x0000)) isa UInt16
454+
@test @inferred(count(isodd, reshape(1:9, 3, 3), dims=:, init=Int128(0))) isa Int128

0 commit comments

Comments
 (0)