|
1 | 1 | export HierarchicalMeasure
|
2 | 2 |
|
3 | 3 |
|
4 |
| -struct HierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure |
| 4 | +# TODO: Document and use FlattenMode |
| 5 | +abstract type FlattenMode end |
| 6 | +struct NoFlatten <: FlattenMode end |
| 7 | +struct AutoFlatten <: FlattenMode end |
| 8 | + |
| 9 | + |
| 10 | +struct HierarchicalMeasure{F,M<:AbstractMeasure,FM<:FlattenMode} <: AbstractMeasure |
5 | 11 | f::F
|
6 | 12 | m::M
|
7 |
| - dof_m::Int |
| 13 | + flatten_mode::FM |
8 | 14 | end
|
9 | 15 |
|
| 16 | +# TODO: Document |
| 17 | +const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,NoFlatten} |
| 18 | +export HierarchicalProductMeasure |
10 | 19 |
|
11 |
| -function HierarchicalMeasure(f, m::AbstractMeasure, ::NoDOF) |
12 |
| - throw(ArgumentError("Primary measure in HierarchicalMeasure must have fixed and known DOF")) |
13 |
| -end |
| 20 | +HierarchicalProductMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, NoFlatten()) |
| 21 | + |
| 22 | +# TODO: Document |
| 23 | +const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,AutoFlatten} |
| 24 | +export FlatHierarchicalMeasure |
| 25 | + |
| 26 | +FlatHierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, AutoFlatten()) |
| 27 | + |
| 28 | +HierarchicalMeasure(f, m::AbstractMeasure) = FlatHierarchicalMeasure(f, m) |
14 | 29 |
|
15 |
| -HierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, dynamic(getdof(m))) |
16 | 30 |
|
17 | 31 |
|
18 |
| -function _split_variate(h::HierarchicalMeasure, x) |
19 |
| - # TODO: Splitting x will be more complicated in general: |
20 |
| - x_primary, x_secondary = x |
21 |
| - return (x_primary, x_secondary) |
| 32 | +function _split_variate_after(::NoFlatten, μ::AbstractMeasure, x::Tuple{2}) |
| 33 | + @assert x isa Tuple{2} |
| 34 | + return x[1], x[2] |
22 | 35 | end
|
23 | 36 |
|
24 | 37 |
|
25 |
| -function _combine_variates(x_primary, x_secondary) |
26 |
| - # TODO: Must offer optional flattening |
27 |
| - return (x_primary, x_secondary) |
| 38 | +function _split_variate_after(::AutoFlatten, μ::AbstractMeasure, x) |
| 39 | + a_test = testvalue(μ) |
| 40 | + return _autosplit_variate_after_testvalue(a_test, x) |
28 | 41 | end
|
29 | 42 |
|
| 43 | +function _autosplit_variate_after_testvalue(::Any, x) |
| 44 | + @assert x isa Tuple{2} |
| 45 | + return x[1], x[2] |
| 46 | +end |
30 | 47 |
|
31 |
| -function localmeasure(h::HierarchicalMeasure, x) |
32 |
| - x_primary, x_secondary = _split_variate(h, x) |
33 |
| - m_primary = h.m |
34 |
| - m_primary_local = localmeasure(m_primary, x_primary) |
35 |
| - m_secondary = m.f(x_secondary) |
36 |
| - m_secondary_local = localmeasure(m_secondary, x_secondary) |
37 |
| - # TODO: Must optionally return a flattened product measure |
38 |
| - return productmeasure(m_primary_local, m_secondary_local) |
| 48 | +function _autosplit_variate_after_testvalue(a_test::AbstractVector, x::AbstractVector) |
| 49 | + n, m = length(eachindex(a_test)), length(eachindex(x)) |
| 50 | + # TODO: Use getindex or view? |
| 51 | + return x[begin:n], x[begin+n:m] |
39 | 52 | end
|
40 | 53 |
|
| 54 | +function _autosplit_variate_after_testvalue(::Tuple{N}, x::Tuple{M}) where {N,M} |
| 55 | + return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M)) |
| 56 | +end |
41 | 57 |
|
42 |
| -@inline function insupport(h::HierarchicalMeasure, x) |
43 |
| - # Only test primary for efficiency: |
44 |
| - x_primary = _split_variate(h, x)[1] |
45 |
| - insupport(h.m, x_primary) |
| 58 | +@generated function _autosplit_variate_after_testvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} |
| 59 | + # TODO: implement |
| 60 | + @assert false |
46 | 61 | end
|
47 | 62 |
|
48 | 63 |
|
49 |
| -#!!!!!!! WON'T WORK: Only use primary measure for efficiency: |
50 |
| -logdensity_type(h::HierarchicalMeasure{F,M}, ::Type{T}) where {F,M,T} = unstatic(float(logdensity_type(M, T))) |
51 | 64 |
|
52 |
| -# Can't implement logdensity_def(::HierarchicalMeasure, x) directly. |
| 65 | +_combine_variates(::NoFlatten, a::Any, b::Any) = (a, b) |
| 66 | + |
| 67 | + |
| 68 | +_combine_variates(::AutoFlatten, a::Any, b::Any) = _autoflat_combine_variates(a, b) |
| 69 | + |
| 70 | +_autoflat_combine_variates(a::Any, b::Any) = (a, b) |
| 71 | + |
| 72 | +_autoflat_combine_variates(a::AbstractVector, b::AbstractVector) = vcat(a, b) |
| 73 | + |
| 74 | +_autoflat_combine_variates(a::Tuple, b::Tuple) = (a, b) |
| 75 | + |
| 76 | +# TODO: Check that names don't overlap: |
| 77 | +_autoflat_combine_variates(a::NamedTuple, b::NamedTuple) = merge(a, b) |
| 78 | + |
| 79 | + |
| 80 | +_local_productmeasure(::NoFlatten, μ1, μ2) = productmeasure(μ1, μ2) |
| 81 | + |
| 82 | +# TODO: _local_productmeasure(::AutoFlatten, μ1, μ2) = productmeasure(μ1, μ2) |
| 83 | +# Needs a FlatProductMeasure type. |
| 84 | + |
| 85 | +function _localmeasure_with_rest(μ::HierarchicalProductMeasure, x) |
| 86 | + μ_primary = μ.m |
| 87 | + local_primary, x_secondary = _localmeasure_with_rest(μ_primary, x) |
| 88 | + μ_secondary = μ.f(x_secondary) |
| 89 | + local_secondary, x_rest = _localmeasure_with_rest(μ_secondary, x_secondary) |
| 90 | + return _local_productmeasure(μ.flatten_mode, local_primary, local_secondary), x_rest |
| 91 | +end |
| 92 | + |
| 93 | +function _localmeasure_with_rest(μ::AbstractMeasure, x) |
| 94 | + x_checked = checked_arg(μ, x) |
| 95 | + return localmeasure(μ, x_checked), Fill(zero(eltype(x)), 0) |
| 96 | +end |
| 97 | + |
| 98 | +function localmeasure(μ::HierarchicalProductMeasure, x) |
| 99 | + h_local, x_rest = _localmeasure_with_rest(μ, x) |
| 100 | + if !isempty(x_rest) |
| 101 | + throw(ArgumentError("Variate too long while computing localmeasure of HierarchicalMeasure")) |
| 102 | + end |
| 103 | + return h_local |
| 104 | +end |
53 | 105 |
|
54 |
| -# Can't implement getdof(::HierarchicalMeasure) efficiently |
55 | 106 |
|
56 |
| -# No way to return a functional base measure: |
57 |
| -struct _BaseOfHierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure end |
58 |
| -@inline basemeasure(::HierarchicalMeasure{F,M}) where {F,M} = _BaseOfHierarchicalMeasure{F,M}() |
| 107 | +@inline insupport(::HierarchicalMeasure, x) = NoFastInsupport() |
59 | 108 |
|
60 | 109 | @inline getdof(μ::HierarchicalMeasure) = NoDOF{typeof(μ)}()
|
61 | 110 |
|
62 | 111 | # Bypass `checked_arg`, would require potentially costly evaluation of h.f:
|
63 | 112 | @inline checked_arg(::HierarchicalMeasure, x) = x
|
64 | 113 |
|
65 |
| -function unsafe_logdensityof(h::HierarchicalMeasure, x) |
66 |
| - x_primary, x_secondary = _split_variate(h, x) |
67 |
| - h_primary, h_secondary = h.m, h.f(x_secondary) |
68 |
| - unsafe_logdensityof(h_primary, x_primary) + logdensityof(h_secondary, x_secondary) |
69 |
| -end |
| 114 | +rootmeasure(::HierarchicalMeasure) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for HierarchicalMeasure")) |
| 115 | + |
| 116 | +basemeasure(::HierarchicalMeasure) = throw(ArgumentError("basemeasure is not available for HierarchicalMeasure")) |
| 117 | + |
| 118 | +logdensity_def(::HierarchicalMeasure, x) = throw(ArgumentError("logdensity_def is not available for HierarchicalMeasure")) |
| 119 | + |
| 120 | + |
| 121 | +# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient |
| 122 | +# # for AutoFlatten, since variate will be split in `localmeasure` and then |
| 123 | +# # split again in log-density evaluation. Maybe add something like |
| 124 | +# function unsafe_logdensityof(h::HierarchicalMeasure, x) |
| 125 | +# local_primary, local_secondary, x_primary, x_secondary = ... |
| 126 | +# # Need to call full logdensityof for h_secondary since x_secondary hasn't |
| 127 | +# # been checked yet: |
| 128 | +# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary) |
| 129 | +# end |
70 | 130 |
|
71 | 131 |
|
72 | 132 | function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::HierarchicalMeasure) where {T<:Real}
|
73 | 133 | x_primary = rand(rng, T, h.m)
|
74 | 134 | x_secondary = rand(rng, T, h.f(x_primary))
|
75 |
| - return _combine_variates(x_primary, x_secondary) |
| 135 | + return _combine_variates(h.flatten_mode, x_primary, x_secondary) |
76 | 136 | end
|
77 | 137 |
|
78 | 138 |
|
79 |
| -function _split_measure_at(μ::PowerMeasure{M, Tuple{R}}, n::Integer) where {M<:StdMeasure,R} |
80 |
| - dof_μ = getdof(μ) |
81 |
| - return M()^n, M()^(dof_μ - n) |
82 |
| -end |
83 |
| - |
84 | 139 |
|
85 |
| -function transport_def( |
86 |
| - ν::PowerMeasure{M, Tuple{R}}, |
87 |
| - μ::HierarchicalMeasure, |
88 |
| - x, |
89 |
| -) where {M<:StdMeasure,R} |
90 |
| - ν_primary, ν_secondary = _split_measure_at(ν, μ.dof_m) |
91 |
| - x_primary, x_secondary = _split_variate(μ, x) |
| 140 | +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::HierarchicalMeasure, x) |
92 | 141 | μ_primary = μ.m
|
| 142 | + y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) |
93 | 143 | μ_secondary = μ.f(x_secondary)
|
94 |
| - y_primary = transport_to(ν_primary, μ_primary, x_primary) |
95 |
| - y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary) |
96 |
| - return vcat(y_primary, y_secondary) |
| 144 | + y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary) |
| 145 | + return _combine_variates(μ.flatten_mode, y_primary, y_secondary), x_rest |
| 146 | +end |
| 147 | + |
| 148 | +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) |
| 149 | + dof_μ = getdof(μ) |
| 150 | + x_μ, x_rest = _split_variate_after(flatten_mode, μ, x) |
| 151 | + y = transport_to(ν_inner^dof_μ, μ, x_μ) |
| 152 | + return y, x_rest |
| 153 | +end |
| 154 | + |
| 155 | +function transport_def(ν::_PowerStdMeasure{1}, μ::HierarchicalMeasure, x) |
| 156 | + ν_inner = _get_inner_stdmeasure(ν) |
| 157 | + y, x_rest = _to_std_with_rest(ν_inner, μ, x) |
| 158 | + if !isempty(x_rest) |
| 159 | + throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure")) |
| 160 | + end |
| 161 | + return y |
97 | 162 | end
|
98 | 163 |
|
99 | 164 |
|
100 |
| -function transport_def( |
101 |
| - ν::HierarchicalMeasure, |
102 |
| - μ::PowerMeasure{M, Tuple{R}}, |
103 |
| - x, |
104 |
| -) where {M<:StdMeasure,R} |
105 |
| - dof_primary = ν.dof_m |
106 |
| - μ_primary, μ_secondary = _split_measure_at(μ, dof_primary) |
107 |
| - x_primary, x_secondary = x[begin:begin+dof_primary-1], x[begin+dof_primary:end] |
| 165 | +function _from_std_with_rest(ν::HierarchicalMeasure, μ_inner::StdMeasure, x) |
108 | 166 | ν_primary = ν.m
|
109 |
| - y_primary = transport_to(ν_primary, μ_primary, x_primary) |
| 167 | + y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x) |
110 | 168 | ν_secondary = ν.f(y_primary)
|
111 |
| - y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary) |
112 |
| - return _combine_variates(y_primary, y_secondary) |
| 169 | + y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary) |
| 170 | + return _combine_variates(ν.flatten_mode, y_primary, y_secondary), x_rest |
| 171 | +end |
| 172 | + |
| 173 | +function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) |
| 174 | + dof_ν = getdof(ν) |
| 175 | + len_x = length(eachindex(x)) |
| 176 | + |
| 177 | + # Since we can't check DOF of original HierarchicalMeasure, we could "run out x" if |
| 178 | + # the original x was too short. `transport_to` below will detect this, but better |
| 179 | + # throw a more informative exception here: |
| 180 | + if len_x < dof_ν |
| 181 | + throw(ArgumentError("Variate too short during transport involving HierarchicalMeasure")) |
| 182 | + end |
| 183 | + |
| 184 | + y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1]) |
| 185 | + x_rest = Fill(zero(eltype(x)), dof_ν - len_x) |
| 186 | + return y, x_rest |
| 187 | +end |
| 188 | + |
| 189 | +function transport_def(ν::HierarchicalMeasure, μ::_PowerStdMeasure{1}, x) |
| 190 | + # Sanity check, should be checked by transport machinery already: |
| 191 | + @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector |
| 192 | + μ_inner = _get_inner_stdmeasure(μ) |
| 193 | + y, x_rest = _from_std_with_rest(ν, μ_inner, x) |
| 194 | + if !isempty(x_rest) |
| 195 | + throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure")) |
| 196 | + end |
| 197 | + return y |
113 | 198 | end
|
0 commit comments