diff --git a/Project.toml b/Project.toml index f45f63b..c6dfd55 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SentinelArrays" uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" authors = ["Jacob Quinn "] -version = "1.2.10" +version = "1.2.11" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/src/chainedvector.jl b/src/chainedvector.jl index 7b6beb2..20be359 100644 --- a/src/chainedvector.jl +++ b/src/chainedvector.jl @@ -19,6 +19,7 @@ function ChainedVector(arrays::Vector{A}) where {A <: AbstractVector{T}} where { inds = Vector{Int}(undef, n) x = 0 @inbounds for i = 1:n + # note that arrays[i] can have zero length x += length(arrays[i]) inds[i] = x end @@ -50,18 +51,28 @@ end # efficient iteration @inline function Base.iterate(A::ChainedVector) length(A) == 0 && return nothing - i = 2 + i = 1 chunk = 1 chunk_i = 1 chunk_len = A.inds[1] - if i > chunk_len + while i > chunk_len chunk += 1 - chunk_i = 1 - @inbounds chunk_len = A.inds[min(length(A.inds), chunk)] + @inbounds chunk_len = A.inds[chunk] + end + x = A.arrays[chunk][1] + # find next valid index + i += 1 + if i > chunk_len + while true + chunk += 1 + chunk > length(A.inds) && break + @inbounds chunk_len = A.inds[chunk] + i <= chunk_len && break + end else chunk_i += 1 end - return A.arrays[1][1], (i, chunk, chunk_i, chunk_len, length(A)) + return x, (i, chunk, chunk_i, chunk_len, length(A)) end @inline function Base.iterate(A::ChainedVector, (i, chunk, chunk_i, chunk_len, len)) @@ -69,9 +80,13 @@ end @inbounds x = A.arrays[chunk][chunk_i] i += 1 if i > chunk_len - chunk += 1 chunk_i = 1 - @inbounds chunk_len = A.inds[min(length(A.inds), chunk)] + while true + chunk += 1 + chunk > length(A.inds) && break + @inbounds chunk_len = A.inds[chunk] + i <= chunk_len && break + end else chunk_i += 1 end diff --git a/test/runtests.jl b/test/runtests.jl index 8a3b3e9..7f00faa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -542,4 +542,48 @@ c2 = copy(c) deleteat!(c2, Int[]) @test length(c2) == 15 +@testset "iteration protocol on ChainedVector" begin + for len in 0:6 + cv = ChainedVector([1:len]) + @test length(cv) == len + c = 0 + for (i, v) in enumerate(cv) + c += 1 + @test i == v + end + @test c == len + for j in 0:len + cv = ChainedVector([1:j, j+1:len]) + @test length(cv) == len + c = 0 + for (i, v) in enumerate(cv) + c += 1 + @test i == v + end + @test c == len + for k in j:len + cv = ChainedVector([1:j, j+1:k, k+1:len]) + @test length(cv) == len + c = 0 + for (i, v) in enumerate(cv) + c += 1 + @test i == v + end + @test c == len + + for l in k:len + cv = ChainedVector([1:j, j+1:k, k+1:l, l+1:len]) + @test length(cv) == len + c = 0 + for (i, v) in enumerate(cv) + c += 1 + @test i == v + end + @test c == len + end + end + end + end +end + end