Skip to content

Commit 506fbdf

Browse files
authored
add init keyword argument to count() (#37461)
1 parent d5ad85a commit 506fbdf

File tree

7 files changed

+51
-21
lines changed

7 files changed

+51
-21
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ New library features
117117
* The postfix operator `'ᵀ` can now be used as an alias for `transpose` ([#38043]).
118118
* `keys(io::IO)` has been added, which returns all keys of `io` if `io` is an `IOContext` and an empty
119119
`Base.KeySet` otherwise ([#37753]).
120+
* `count` now accepts an optional `init` argument to control the accumulation type ([#37461]).
120121

121122
Standard library changes
122123
------------------------

base/bitarray.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,15 +1386,15 @@ circshift!(B::BitVector, i::Integer) = circshift!(B, B, i)
13861386

13871387
## count & find ##
13881388

1389-
function bitcount(Bc::Vector{UInt64})
1390-
n = 0
1389+
function bitcount(Bc::Vector{UInt64}; init::T=0) where {T}
1390+
n::T = init
13911391
@inbounds for i = 1:length(Bc)
1392-
n += count_ones(Bc[i])
1392+
n = (n + count_ones(Bc[i])) % T
13931393
end
13941394
return n
13951395
end
13961396

1397-
count(B::BitArray) = bitcount(B.chunks)
1397+
count(B::BitArray; init=0) = bitcount(B.chunks; init)
13981398

13991399
function unsafe_bitfindnext(Bc::Vector{UInt64}, start::Int)
14001400
chunk_start = _div64(start-1)+1

base/reduce.jl

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -938,12 +938,15 @@ end
938938
_bool(f) = x->f(x)::Bool
939939

940940
"""
941-
count(p, itr) -> Integer
942-
count(itr) -> Integer
941+
count([f=identity,] itr; init=0) -> Integer
943942
944-
Count the number of elements in `itr` for which predicate `p` returns `true`.
945-
If `p` is omitted, counts the number of `true` elements in `itr` (which
946-
should be a collection of boolean values).
943+
Count the number of elements in `itr` for which the function `f` returns `true`.
944+
If `f` is omitted, count the number of `true` elements in `itr` (which
945+
should be a collection of boolean values). `init` optionally specifies the value
946+
to start counting from and therefore also determines the output type.
947+
948+
!!! compat "Julia 1.6"
949+
`init` keyword was added in Julia 1.6.
947950
948951
# Examples
949952
```jldoctest
@@ -952,32 +955,35 @@ julia> count(i->(4<=i<=6), [2,3,4,5,6])
952955
953956
julia> count([true, false, true, true])
954957
3
958+
959+
julia> count(>(3), 1:7, init=0x03)
960+
0x07
955961
```
956962
"""
957-
count(itr) = count(identity, itr)
963+
count(itr; init=0) = count(identity, itr; init)
958964

959-
count(f, itr) = _simple_count(f, itr)
965+
count(f, itr; init=0) = _simple_count(f, itr, init)
960966

961-
function _simple_count(pred, itr)
962-
n = 0
967+
function _simple_count(pred, itr, init::T) where {T}
968+
n::T = init
963969
for x in itr
964970
n += pred(x)::Bool
965971
end
966972
return n
967973
end
968974

969-
function count(::typeof(identity), x::Array{Bool})
970-
n = 0
975+
function _simple_count(::typeof(identity), x::Array{Bool}, init::T=0) where {T}
976+
n::T = init
971977
chunks = length(x) ÷ sizeof(UInt)
972978
mask = 0x0101010101010101 % UInt
973979
GC.@preserve x begin
974980
ptr = Ptr{UInt}(pointer(x))
975981
for i in 1:chunks
976-
n += count_ones(unsafe_load(ptr, i) & mask)
982+
n = (n + count_ones(unsafe_load(ptr, i) & mask)) % T
977983
end
978984
end
979985
for i in sizeof(UInt)*chunks+1:length(x)
980-
n += x[i]
986+
n = (n + x[i]) % T
981987
end
982988
return n
983989
end

base/reducedim.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,9 @@ dimensions.
369369
!!! compat "Julia 1.5"
370370
`dims` keyword was added in Julia 1.5.
371371
372+
!!! compat "Julia 1.6"
373+
`init` keyword was added in Julia 1.6.
374+
372375
# Examples
373376
```jldoctest
374377
julia> A = [1 2; 3 4]
@@ -386,11 +389,11 @@ julia> count(<=(2), A, dims=2)
386389
0
387390
```
388391
"""
389-
count(A::AbstractArrayOrBroadcasted; dims=:) = count(identity, A, dims=dims)
390-
count(f, A::AbstractArrayOrBroadcasted; dims=:) = _count(f, A, dims)
392+
count(A::AbstractArrayOrBroadcasted; dims=:, init=0) = count(identity, A; dims, init)
393+
count(f, A::AbstractArrayOrBroadcasted; dims=:, init=0) = _count(f, A, dims, init)
391394

392-
_count(f, A::AbstractArrayOrBroadcasted, dims::Colon) = _simple_count(f, A)
393-
_count(f, A::AbstractArrayOrBroadcasted, dims) = mapreduce(_bool(f), add_sum, A, dims=dims, init=0)
395+
_count(f, A::AbstractArrayOrBroadcasted, dims::Colon, init) = _simple_count(f, A, init)
396+
_count(f, A::AbstractArrayOrBroadcasted, dims, init) = mapreduce(_bool(f), add_sum, A; dims, init)
394397

395398
"""
396399
count!([f=identity,] r, A)

test/bitarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,8 @@ timesofar("datamove")
12191219
@check_bit_operation findall(falses(t)) ret_type
12201220
@check_bit_operation findall(bitrand(t)) ret_type
12211221
end
1222+
1223+
@test count(trues(2, 2), init=0x03) === 0x07
12221224
end
12231225

12241226
timesofar("find")

test/reduce.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,12 @@ struct NonFunctionIsZero end
520520
@test count(NonFunctionIsZero(), [0]) == 1
521521
@test count(NonFunctionIsZero(), [1]) == 0
522522

523+
@test count(Iterators.repeated(true, 3), init=0x04) === 0x07
524+
@test count(!=(2), Iterators.take(1:7, 3), init=Int32(0)) === Int32(2)
525+
@test count(identity, [true, false], init=Int8(5)) === Int8(6)
526+
@test count(!, [true false; false true], dims=:, init=Int16(0)) === Int16(2)
527+
@test isequal(count(identity, [true false; false true], dims=2, init=UInt(4)), reshape(UInt[5, 5], 2, 1))
528+
523529
## cumsum, cummin, cummax
524530

525531
z = rand(10^6)

test/reducedim.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
7777
@test @inferred(maximum(abs, Areduc, dims=region)) safe_maxabs(Areduc, region)
7878
@test @inferred(minimum(abs, Areduc, dims=region)) safe_minabs(Areduc, region)
7979
@test @inferred(count(!, Breduc, dims=region)) safe_count(.!Breduc, region)
80+
81+
@test isequal(
82+
@inferred(count(Breduc, dims=region, init=0x02)),
83+
safe_count(Breduc, region) .% UInt8 .+ 0x02,
84+
)
85+
@test isequal(
86+
@inferred(count(!, Breduc, dims=region, init=Int16(0))),
87+
safe_count(.!Breduc, region) .% Int16,
88+
)
8089
end
8190

8291
# Combining dims and init
@@ -446,3 +455,6 @@ end
446455
@test_throws TypeError count([1], dims=1)
447456
@test_throws TypeError count!([1], [1])
448457
end
458+
459+
@test @inferred(count(false:true, dims=:, init=0x0004)) === 0x0005
460+
@test @inferred(count(isodd, reshape(1:9, 3, 3), dims=:, init=Int128(0))) === Int128(5)

0 commit comments

Comments
 (0)