Skip to content

Support indexing involving end #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 3, 2019
Merged
2 changes: 1 addition & 1 deletion src/Setfield.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__precompile__(true)
module Setfield
using MacroTools
using MacroTools: isstructdef, splitstructdef
using MacroTools: isstructdef, splitstructdef, postwalk

include("lens.jl")
include("sugar.jl")
Expand Down
9 changes: 9 additions & 0 deletions src/lens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,15 @@ Base.@propagate_inbounds set(obj, ::ConstIndexLens{I}, val) where I =
end
end

struct DynamicIndexLens{F} <: Lens
f::F
end

Base.@propagate_inbounds get(obj, I::DynamicIndexLens) = obj[I.f(obj)...]

Base.@propagate_inbounds set(obj, I::DynamicIndexLens, val) =
setindex(obj, val, I.f(obj)...)

"""
FunctionLens(f)
@lens f(_)
Expand Down
30 changes: 30 additions & 0 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,30 @@ end

is_interpolation(x) = x isa Expr && x.head == :$

foldtree(op, init, x) = op(init, x)
foldtree(op, init, ex::Expr) =
op(foldl((acc, x) -> foldtree(op, acc, x), ex.args; init=init), ex)

need_dynamic_lens(ex) =
foldtree(false, ex) do yes, x
yes || x === :end || x === :_
end

replace_underscore(ex, to) = postwalk(x -> x === :_ ? to : x, ex)

function lower_index(collection::Symbol, index, dim)
if isexpr(index, :call)
return Expr(:call, lower_index.(collection, index.args, dim)...)
elseif index === :end
if dim === nothing
return :($(Base.lastindex)($collection))
else
return :($(Base.lastindex)($collection, $dim))
end
end
return index
end

function parse_obj_lenses(ex)
if @capture(ex, front_[indices__])
obj, frontlens = parse_obj_lenses(front)
Expand All @@ -63,6 +87,12 @@ function parse_obj_lenses(ex)
end
index = esc(Expr(:tuple, [x.args[1] for x in indices]...))
lens = :(ConstIndexLens{$index}())
elseif any(need_dynamic_lens, indices)
@gensym collection
indices = replace_underscore.(indices, collection)
dims = length(indices) == 1 ? nothing : 1:length(indices)
lindices = esc.(lower_index.(collection, indices, dims))
lens = :(DynamicIndexLens($(esc(collection)) -> ($(lindices...),)))
else
index = esc(Expr(:tuple, indices...))
lens = :(IndexLens($index))
Expand Down
34 changes: 34 additions & 0 deletions test/test_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ end
i = 1
si = @set t.a[i] = 10
@test s1 === si
se = @set t.a[end] = 20
@test se === T((1,20),(3,4))
se1 = @set t.a[end-1] = 10
@test s1 === se1

s1 = @set t.a[$1] = 10
@test s1 === T((10,2),(3,4))
Expand Down Expand Up @@ -191,6 +195,8 @@ end
@lens _.b.a.b[i]
@lens _.b.a.b[$2]
@lens _.b.a.b[$i]
@lens _.b.a.b[end]
@lens _.b.a.b[identity(end) - 1]
@lens _
]
val1, val2 = randn(2)
Expand Down Expand Up @@ -226,6 +232,8 @@ end
((@lens _.b.a.b[$(i+1)]), 4 ),
((@lens _.b.a.b[$2] ), 4.0),
((@lens _.b.a.b[$(i+1)]), 4.0),
((@lens _.b.a.b[end]), 4.0),
((@lens _.b.a.b[end÷2+1]), 4.0),
((@lens _ ), obj),
((@lens _ ), :xy),
(MultiPropertyLens((a=(@lens _), b=(@lens _))), (a=1, b=2)),
Expand All @@ -238,25 +246,51 @@ end

@testset "IndexLens" begin
l = @lens _[]
@test l isa Setfield.IndexLens
x = randn()
obj = Ref(x)
@test get(obj, l) == x

l = @lens _[][]
@test l.outer isa Setfield.IndexLens
@test l.inner isa Setfield.IndexLens
inner = Ref(x)
obj = Base.RefValue{typeof(inner)}(inner)
@test get(obj, l) == x

obj = (1,2,3)
l = @lens _[1]
@test l isa Setfield.IndexLens
@test get(obj, l) == 1
@test set(obj, l, 6) == (6,2,3)


l = @lens _[1:3]
@test l isa Setfield.IndexLens
@test get([4,5,6,7], l) == [4,5,6]
end

@testset "DynamicIndexLens" begin
l = @lens _[end]
@test l isa Setfield.DynamicIndexLens
obj = (1,2,3)
@test get(obj, l) == 3
@test set(obj, l, true) == (1,2,true)

l = @lens _[end÷2]
@test l isa Setfield.DynamicIndexLens
obj = (1,2,3)
@test get(obj, l) == 1
@test set(obj, l, true) == (true,2,3)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some more complicated lenses? Composed and multi index? Also can you add show tests?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I added a few more tests. I completed forgot about multi-index version so that's implemented/tested as well.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. @lens _[end, end]


two = 2
plusone(x) = x + 1
l = @lens _.a[plusone(end) - two].b
obj = (a=(1, (a=10, b=20), 3), b=4)
@test get(obj, l) == 20
@test set(obj, l, true) == (a=(1, (a=10, b=true), 3), b=4)
end

@testset "ConstIndexLens" begin
obj = (1, 2.0, '3')
l = @lens _[$1]
Expand Down
20 changes: 20 additions & 0 deletions test/test_staticarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,25 @@ using StaticArrays
v = @SVector [1,2,3]
@test (@set v[1] = 10) === @SVector [10,2,3]
@test_broken (@set v[1] = π) === @SVector [π,2,3]

@testset "Multi-dynamic indexing" begin
two = 2
plusone(x) = x + 1
l1 = @lens _.a[2, 1].b
l2 = @lens _.a[plusone(end) - two, end÷2].b
m_orig = @SMatrix [
(a=1, b=10) (a=2, b=20)
(a=3, b=30) (a=4, b=40)
(a=5, b=50) (a=6, b=60)
]
m_mod = @SMatrix [
(a=1, b=10) (a=2, b=20)
(a=3, b=3000) (a=4, b=40)
(a=5, b=50) (a=6, b=60)
]
obj = (a=m_orig, b=4)
@test get(obj, l1) === get(obj, l2) === 30
@test set(obj, l1, 3000) === set(obj, l2, 3000) === (a=m_mod, b=4)
end
end
end