Skip to content

Commit 9d35b63

Browse files
authored
Fix get function (#93)
* Fix `get` function * fix definition and tests * use `optic` instead of`.optic` * add tests for `set` * fix more test errors * fix doctest * version bump
1 parent cc9c8ed commit 9d35b63

File tree

3 files changed

+33
-27
lines changed

3 files changed

+33
-27
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probablistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.8.0"
6+
version = "0.8.1"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/varname.jl

+2-14
Original file line numberDiff line numberDiff line change
@@ -122,23 +122,11 @@ getoptic(vn::VarName) = vn.optic
122122
"""
123123
get(obj, vn::VarName{sym})
124124
125-
Alias for `getoptic(vn)(obj)`.
126-
127-
# Example
128-
129-
```jldoctest; setup = :(nt = (a = 1, b = (c = [1, 2, 3],)); name = :nt)
130-
julia> get(nt, @varname(nt.a))
131-
1
132-
133-
julia> get(nt, @varname(nt.b.c[1]))
134-
1
135-
136-
julia> get(nt, @varname(\$name.b.c[1]))
137-
1
125+
Alias for `(PropertyLens{sym}() ⨟ getoptic(vn))(obj)`.
138126
```
139127
"""
140128
function Base.get(obj, vn::VarName{sym}) where {sym}
141-
return getoptic(vn)(obj)
129+
return (PropertyLens{sym}() getoptic(vn))(obj)
142130
end
143131

144132
"""

test/varname.jl

+30-12
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,19 @@ macro test_strict_subsumption(x, y)
1414
end
1515
end
1616

17+
function test_equal(o1::VarName{sym1}, o2::VarName{sym2}) where {sym1, sym2}
18+
return sym1 === sym2 && test_equal(o1.optic, o2.optic)
19+
end
20+
function test_equal(o1::ComposedFunction, o2::ComposedFunction)
21+
return test_equal(o1.inner, o2.inner) && test_equal(o1.outer, o2.outer)
22+
end
23+
function test_equal(o1::Accessors.IndexLens, o2::Accessors.IndexLens)
24+
return test_equal(o1.indices, o2.indices)
25+
end
26+
function test_equal(o1, o2)
27+
return o1 == o2
28+
end
29+
1730
@testset "varnames" begin
1831
@testset "construction & concretization" begin
1932
i = 1:10
@@ -27,14 +40,22 @@ end
2740

2841
# concretization
2942
y = zeros(10, 10)
30-
x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], );
43+
x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0],);
3144

3245
@test @varname(y[begin, i], true) == @varname(y[1, 1:10])
33-
@test get(y, @varname(y[:], true)) == get(y, @varname(y[1:100]))
34-
@test get(y, @varname(y[:, begin], true)) == get(y, @varname(y[1:10, 1]))
35-
@test getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] ===
46+
@test test_equal(@varname(y[:], true), @varname(y[1:100]))
47+
@test test_equal(@varname(y[:, begin], true), @varname(y[1:10, 1]))
48+
@test getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] ===
3649
AbstractPPL.ConcretizedSlice(to_indices(y, (:,))[1])
37-
@test get(x, @varname(x.a[1:end, end][:], true)) == get(x, @varname(x.a[1:3,2][1:3]))
50+
@test test_equal(@varname(x.a[1:end, end][:], true), @varname(x.a[1:3,2][1:3]))
51+
end
52+
53+
@testset "get & set" begin
54+
x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], b = 1.0);
55+
@test get(x, @varname(a[1, 2])) == 2.0
56+
@test get(x, @varname(b)) == 1.0
57+
@test set(x, @varname(a[1, 2]), 10) == (a = [1.0 10.0; 3.0 4.0; 5.0 6.0], b = 1.0)
58+
@test set(x, @varname(b), 10) == (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], b = 10.0)
3859
end
3960

4061
@testset "subsumption with standard indexing" begin
@@ -83,10 +104,10 @@ end
83104

84105
@testset "non-standard indexing" begin
85106
A = rand(10, 10)
86-
@test get(A, @varname(A[1, Not(3)], true)) == get(A, @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]]))
107+
@test test_equal(@varname(A[1, Not(3)], true), @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]]))
87108

88109
B = OffsetArray(A, -5, -5) # indices -4:5×-4:5
89-
@test collect(get(B, @varname(B[1, :], true))) == collect(get(B, @varname(B[1, -4:5])))
110+
@test test_equal(@varname(B[1, :], true), @varname(B[1, -4:5]))
90111

91112
end
92113
@testset "type stability" begin
@@ -96,15 +117,12 @@ end
96117
@inferred VarName{:a}(PropertyLens(:b))
97118
@inferred VarName{:a}(Accessors.opcompose(IndexLens(1), PropertyLens(:b)))
98119

99-
a = [1, 2, 3]
100-
@inferred get(a, @varname(a[1]))
101-
102120
b = (a=[1, 2, 3],)
103-
@inferred get(b, @varname(b.a[1]))
121+
@inferred get(b, @varname(a[1]))
104122
@inferred Accessors.set(b, @varname(a[1]), 10)
105123

106124
c = (b=(a=[1, 2, 3],),)
107-
@inferred get(c, @varname(c.b.a[1]))
125+
@inferred get(c, @varname(b.a[1]))
108126
@inferred Accessors.set(c, @varname(b.a[1]), 10)
109127
end
110128
end

0 commit comments

Comments
 (0)