Skip to content

Commit 042882d

Browse files
committed
Don't require primary in HierarchicalMeasure to have known DOF
1 parent 73b66a6 commit 042882d

File tree

3 files changed

+167
-68
lines changed

3 files changed

+167
-68
lines changed

src/combinators/hierarchical.jl

+149-64
Original file line numberDiff line numberDiff line change
@@ -1,113 +1,198 @@
11
export HierarchicalMeasure
22

33

4-
struct HierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure
4+
# TODO: Document and use FlattenMode
5+
abstract type FlattenMode end
6+
struct NoFlatten <: FlattenMode end
7+
struct AutoFlatten <: FlattenMode end
8+
9+
10+
struct HierarchicalMeasure{F,M<:AbstractMeasure,FM<:FlattenMode} <: AbstractMeasure
511
f::F
612
m::M
7-
dof_m::Int
13+
flatten_mode::FM
814
end
915

16+
# TODO: Document
17+
const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,NoFlatten}
18+
export HierarchicalProductMeasure
1019

11-
function HierarchicalMeasure(f, m::AbstractMeasure, ::NoDOF)
12-
throw(ArgumentError("Primary measure in HierarchicalMeasure must have fixed and known DOF"))
13-
end
20+
HierarchicalProductMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, NoFlatten())
21+
22+
# TODO: Document
23+
const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,AutoFlatten}
24+
export FlatHierarchicalMeasure
25+
26+
FlatHierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, AutoFlatten())
27+
28+
HierarchicalMeasure(f, m::AbstractMeasure) = FlatHierarchicalMeasure(f, m)
1429

15-
HierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, dynamic(getdof(m)))
1630

1731

18-
function _split_variate(h::HierarchicalMeasure, x)
19-
# TODO: Splitting x will be more complicated in general:
20-
x_primary, x_secondary = x
21-
return (x_primary, x_secondary)
32+
function _split_variate_after(::NoFlatten, μ::AbstractMeasure, x::Tuple{2})
33+
@assert x isa Tuple{2}
34+
return x[1], x[2]
2235
end
2336

2437

25-
function _combine_variates(x_primary, x_secondary)
26-
# TODO: Must offer optional flattening
27-
return (x_primary, x_secondary)
38+
function _split_variate_after(::AutoFlatten, μ::AbstractMeasure, x)
39+
a_test = testvalue(μ)
40+
return _autosplit_variate_after_testvalue(a_test, x)
2841
end
2942

43+
function _autosplit_variate_after_testvalue(::Any, x)
44+
@assert x isa Tuple{2}
45+
return x[1], x[2]
46+
end
3047

31-
function localmeasure(h::HierarchicalMeasure, x)
32-
x_primary, x_secondary = _split_variate(h, x)
33-
m_primary = h.m
34-
m_primary_local = localmeasure(m_primary, x_primary)
35-
m_secondary = m.f(x_secondary)
36-
m_secondary_local = localmeasure(m_secondary, x_secondary)
37-
# TODO: Must optionally return a flattened product measure
38-
return productmeasure(m_primary_local, m_secondary_local)
48+
function _autosplit_variate_after_testvalue(a_test::AbstractVector, x::AbstractVector)
49+
n, m = length(eachindex(a_test)), length(eachindex(x))
50+
# TODO: Use getindex or view?
51+
return x[begin:n], x[begin+n:m]
3952
end
4053

54+
function _autosplit_variate_after_testvalue(::Tuple{N}, x::Tuple{M}) where {N,M}
55+
return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M))
56+
end
4157

42-
@inline function insupport(h::HierarchicalMeasure, x)
43-
# Only test primary for efficiency:
44-
x_primary = _split_variate(h, x)[1]
45-
insupport(h.m, x_primary)
58+
@generated function _autosplit_variate_after_testvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names}
59+
# TODO: implement
60+
@assert false
4661
end
4762

4863

49-
#!!!!!!! WON'T WORK: Only use primary measure for efficiency:
50-
logdensity_type(h::HierarchicalMeasure{F,M}, ::Type{T}) where {F,M,T} = unstatic(float(logdensity_type(M, T)))
5164

52-
# Can't implement logdensity_def(::HierarchicalMeasure, x) directly.
65+
_combine_variates(::NoFlatten, a::Any, b::Any) = (a, b)
66+
67+
68+
_combine_variates(::AutoFlatten, a::Any, b::Any) = _autoflat_combine_variates(a, b)
69+
70+
_autoflat_combine_variates(a::Any, b::Any) = (a, b)
71+
72+
_autoflat_combine_variates(a::AbstractVector, b::AbstractVector) = vcat(a, b)
73+
74+
_autoflat_combine_variates(a::Tuple, b::Tuple) = (a, b)
75+
76+
# TODO: Check that names don't overlap:
77+
_autoflat_combine_variates(a::NamedTuple, b::NamedTuple) = merge(a, b)
78+
79+
80+
_local_productmeasure(::NoFlatten, μ1, μ2) = productmeasure(μ1, μ2)
81+
82+
# TODO: _local_productmeasure(::AutoFlatten, μ1, μ2) = productmeasure(μ1, μ2)
83+
# Needs a FlatProductMeasure type.
84+
85+
function _localmeasure_with_rest::HierarchicalProductMeasure, x)
86+
μ_primary = μ.m
87+
local_primary, x_secondary = _localmeasure_with_rest(μ_primary, x)
88+
μ_secondary = μ.f(x_secondary)
89+
local_secondary, x_rest = _localmeasure_with_rest(μ_secondary, x_secondary)
90+
return _local_productmeasure.flatten_mode, local_primary, local_secondary), x_rest
91+
end
92+
93+
function _localmeasure_with_rest::AbstractMeasure, x)
94+
x_checked = checked_arg(μ, x)
95+
return localmeasure(μ, x_checked), Fill(zero(eltype(x)), 0)
96+
end
97+
98+
function localmeasure::HierarchicalProductMeasure, x)
99+
h_local, x_rest = _localmeasure_with_rest(μ, x)
100+
if !isempty(x_rest)
101+
throw(ArgumentError("Variate too long while computing localmeasure of HierarchicalMeasure"))
102+
end
103+
return h_local
104+
end
53105

54-
# Can't implement getdof(::HierarchicalMeasure) efficiently
55106

56-
# No way to return a functional base measure:
57-
struct _BaseOfHierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure end
58-
@inline basemeasure(::HierarchicalMeasure{F,M}) where {F,M} = _BaseOfHierarchicalMeasure{F,M}()
107+
@inline insupport(::HierarchicalMeasure, x) = NoFastInsupport()
59108

60109
@inline getdof::HierarchicalMeasure) = NoDOF{typeof(μ)}()
61110

62111
# Bypass `checked_arg`, would require potentially costly evaluation of h.f:
63112
@inline checked_arg(::HierarchicalMeasure, x) = x
64113

65-
function unsafe_logdensityof(h::HierarchicalMeasure, x)
66-
x_primary, x_secondary = _split_variate(h, x)
67-
h_primary, h_secondary = h.m, h.f(x_secondary)
68-
unsafe_logdensityof(h_primary, x_primary) + logdensityof(h_secondary, x_secondary)
69-
end
114+
rootmeasure(::HierarchicalMeasure) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for HierarchicalMeasure"))
115+
116+
basemeasure(::HierarchicalMeasure) = throw(ArgumentError("basemeasure is not available for HierarchicalMeasure"))
117+
118+
logdensity_def(::HierarchicalMeasure, x) = throw(ArgumentError("logdensity_def is not available for HierarchicalMeasure"))
119+
120+
121+
# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient
122+
# # for AutoFlatten, since variate will be split in `localmeasure` and then
123+
# # split again in log-density evaluation. Maybe add something like
124+
# function unsafe_logdensityof(h::HierarchicalMeasure, x)
125+
# local_primary, local_secondary, x_primary, x_secondary = ...
126+
# # Need to call full logdensityof for h_secondary since x_secondary hasn't
127+
# # been checked yet:
128+
# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary)
129+
# end
70130

71131

72132
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::HierarchicalMeasure) where {T<:Real}
73133
x_primary = rand(rng, T, h.m)
74134
x_secondary = rand(rng, T, h.f(x_primary))
75-
return _combine_variates(x_primary, x_secondary)
135+
return _combine_variates(h.flatten_mode, x_primary, x_secondary)
76136
end
77137

78138

79-
function _split_measure_at::PowerMeasure{M, Tuple{R}}, n::Integer) where {M<:StdMeasure,R}
80-
dof_μ = getdof(μ)
81-
return M()^n, M()^(dof_μ - n)
82-
end
83-
84139

85-
function transport_def(
86-
ν::PowerMeasure{M, Tuple{R}},
87-
μ::HierarchicalMeasure,
88-
x,
89-
) where {M<:StdMeasure,R}
90-
ν_primary, ν_secondary = _split_measure_at(ν, μ.dof_m)
91-
x_primary, x_secondary = _split_variate(μ, x)
140+
function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::HierarchicalMeasure, x)
92141
μ_primary = μ.m
142+
y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x)
93143
μ_secondary = μ.f(x_secondary)
94-
y_primary = transport_to(ν_primary, μ_primary, x_primary)
95-
y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary)
96-
return vcat(y_primary, y_secondary)
144+
y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary)
145+
return _combine_variates.flatten_mode, y_primary, y_secondary), x_rest
146+
end
147+
148+
function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x)
149+
dof_μ = getdof(μ)
150+
x_μ, x_rest = _split_variate_after(flatten_mode, μ, x)
151+
y = transport_to(ν_inner^dof_μ, μ, x_μ)
152+
return y, x_rest
153+
end
154+
155+
function transport_def::_PowerStdMeasure{1}, μ::HierarchicalMeasure, x)
156+
ν_inner = _get_inner_stdmeasure(ν)
157+
y, x_rest = _to_std_with_rest(ν_inner, μ, x)
158+
if !isempty(x_rest)
159+
throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure"))
160+
end
161+
return y
97162
end
98163

99164

100-
function transport_def(
101-
ν::HierarchicalMeasure,
102-
μ::PowerMeasure{M, Tuple{R}},
103-
x,
104-
) where {M<:StdMeasure,R}
105-
dof_primary = ν.dof_m
106-
μ_primary, μ_secondary = _split_measure_at(μ, dof_primary)
107-
x_primary, x_secondary = x[begin:begin+dof_primary-1], x[begin+dof_primary:end]
165+
function _from_std_with_rest::HierarchicalMeasure, μ_inner::StdMeasure, x)
108166
ν_primary = ν.m
109-
y_primary = transport_to(ν_primary, μ_primary, x_primary)
167+
y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x)
110168
ν_secondary = ν.f(y_primary)
111-
y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary)
112-
return _combine_variates(y_primary, y_secondary)
169+
y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary)
170+
return _combine_variates.flatten_mode, y_primary, y_secondary), x_rest
171+
end
172+
173+
function _from_std_with_rest::AbstractMeasure, μ_inner::StdMeasure, x)
174+
dof_ν = getdof(ν)
175+
len_x = length(eachindex(x))
176+
177+
# Since we can't check DOF of original HierarchicalMeasure, we could "run out x" if
178+
# the original x was too short. `transport_to` below will detect this, but better
179+
# throw a more informative exception here:
180+
if len_x < dof_ν
181+
throw(ArgumentError("Variate too short during transport involving HierarchicalMeasure"))
182+
end
183+
184+
y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1])
185+
x_rest = Fill(zero(eltype(x)), dof_ν - len_x)
186+
return y, x_rest
187+
end
188+
189+
function transport_def::HierarchicalMeasure, μ::_PowerStdMeasure{1}, x)
190+
# Sanity check, should be checked by transport machinery already:
191+
@assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector
192+
μ_inner = _get_inner_stdmeasure(μ)
193+
y, x_rest = _from_std_with_rest(ν, μ_inner, x)
194+
if !isempty(x_rest)
195+
throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure"))
196+
end
197+
return y
113198
end

src/density-core.jl

+12-4
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ To compute a log-density relative to a specific base-measure, see
5555
end
5656

5757
_checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))
58+
@inline _checksupport(::NoFastInsupport, result) = result
5859

5960
import ChainRulesCore
6061
@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result)
@@ -77,6 +78,12 @@ This is "unsafe" because it does not check `insupport(m, x)`.
7778
See also `logdensityof`.
7879
"""
7980
@inline function unsafe_logdensityof::M, x) where {M}
81+
μ_local = localmeasure(μ, x)
82+
# Extra dispatch boundary to reduce number of required specializations of implementation:
83+
return _unsafe_logdensityof_local(μ_local, x)
84+
end
85+
86+
@inline function _unsafe_logdensityof_local::M, x) where {M}
8087
ℓ_0 = logdensity_def(μ, x)
8188
b_0 = μ
8289
Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number
@@ -119,7 +126,7 @@ known to be in the support of both, it can be more efficient to call
119126
end
120127

121128

122-
function _logdensity_rel_impl::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X}
129+
@inline function _logdensity_rel_impl::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X}
123130
T = unstatic(
124131
promote_type(
125132
logdensity_type(μ, X),
@@ -134,16 +141,16 @@ function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, inν::Bool) where
134141
end
135142

136143

137-
function _logdensity_rel_impl::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X}
144+
@inline function _logdensity_rel_impl::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X}
138145
unsafe_logdensity_rel(μ, ν, x)
139146
end
140147

141-
function _logdensity_rel_impl::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X}
148+
@inline function _logdensity_rel_impl::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X}
142149
logd = unsafe_logdensity_rel(μ, ν, x)
143150
return istrue(inμ) ? logd : logd * oftypeof(logd, -Inf)
144151
end
145152

146-
function _logdensity_rel_impl::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X}
153+
@inline function _logdensity_rel_impl::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X}
147154
logd = unsafe_logdensity_rel(μ, ν, x)
148155
return istrue(inν) ? logd : logd * oftypeof(logd, +Inf)
149156
end
@@ -160,6 +167,7 @@ See also `logdensity_rel`.
160167
@inline function unsafe_logdensity_rel::M, ν::N, x::X) where {M,N,X}
161168
μ_local = localmeasure(μ, x)
162169
ν_local = localmeasure(ν, x)
170+
# Extra dispatch boundary to reduce number of required specializations of implementation:
163171
return _unsafe_logdensity_rel_local(μ_local, ν_local, x)
164172
end
165173

src/standard/stdmeasure.jl

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
abstract type StdMeasure <: AbstractMeasure end
22

3+
4+
const _PowerStdMeasure{N,M<:StdMeasure} = PowerMeasure{M,<:NTuple{N,Base.OneTo}}
5+
6+
_get_inner_stdmeasure::_PowerStdMeasure{N,M}) where {N,M} = M()
7+
8+
39
StdMeasure(::typeof(rand)) = StdUniform()
410
StdMeasure(::typeof(randexp)) = StdExponential()
511
StdMeasure(::typeof(randn)) = StdNormal()

0 commit comments

Comments
 (0)