Skip to content

Commit dd9d55c

Browse files
committed
STASH
1 parent 042882d commit dd9d55c

File tree

1 file changed

+35
-29
lines changed

1 file changed

+35
-29
lines changed

src/combinators/hierarchical.jl

+35-29
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,67 @@
11
export HierarchicalMeasure
22

3+
"""
4+
struct HierarchicalMeasure{F,M<:AbstractMeasure,G} <: AbstractMeasure
35
4-
# TODO: Document and use FlattenMode
5-
abstract type FlattenMode end
6-
struct NoFlatten <: FlattenMode end
7-
struct AutoFlatten <: FlattenMode end
6+
Represents a hierarchical measure.
87
9-
10-
struct HierarchicalMeasure{F,M<:AbstractMeasure,FM<:FlattenMode} <: AbstractMeasure
8+
User code should not instantiate `HierarchicalMeasure` directly, use
9+
[`hierarchical_measure`](@ref) instead.
10+
"""
11+
struct HierarchicalMeasure{F,M<:AbstractMeasure,G} <: AbstractMeasure
1112
f::F
1213
m::M
13-
flatten_mode::FM
14+
flatten::G
1415
end
1516

16-
# TODO: Document
17-
const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,NoFlatten}
18-
export HierarchicalProductMeasure
19-
20-
HierarchicalProductMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, NoFlatten())
17+
"""
18+
hierarchical_measure(f, m::AbstractMeasure, flatten)
2119
22-
# TODO: Document
23-
const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,AutoFlatten}
24-
export FlatHierarchicalMeasure
20+
Construct a hierarchical measure from a function `f`, measure `m` and
21+
"""
22+
@inline function hierarchical_measure(f, m::AbstractMeasure, flatten)
23+
F, M, G = Core.Typeof(f), Core.Typeof(m), Core.Typeof(flatten)
24+
HierarchicalProductMeasure{F,M,G}(f, m, flatten)
25+
end
2526

26-
FlatHierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, AutoFlatten())
2727

28-
HierarchicalMeasure(f, m::AbstractMeasure) = FlatHierarchicalMeasure(f, m)
28+
#!!!!!!
29+
const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,::typeof(=>)}
30+
const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,::typeof(vcat)}
2931

3032

3133

32-
function _split_variate_after(::NoFlatten, μ::AbstractMeasure, x::Tuple{2})
33-
@assert x isa Tuple{2}
34-
return x[1], x[2]
34+
function _split_variate(::typeof(=>), ::AbstractMeasure, x::Pair)
35+
return x.first, x.second
3536
end
3637

38+
function _split_variate(flatten::F, μ_primary::AbstractMeasure, x) where F
39+
test_primary = testvalue(μ_primary)
40+
return _split_variate_byvalue(flatten, test_primary, x)
41+
end
3742

38-
function _split_variate_after(::AutoFlatten, μ::AbstractMeasure, x)
39-
a_test = testvalue(μ)
40-
return _autosplit_variate_after_testvalue(a_test, x)
43+
function _split_variate(::Type{F}, μ::AbstractMeasure, x) where F
44+
test_primary = testvalue(μ)
45+
return _split_variate_byvalue(F, test_primary, x)
4146
end
4247

43-
function _autosplit_variate_after_testvalue(::Any, x)
48+
49+
function _split_variate_byvalue(::Any, x)
4450
@assert x isa Tuple{2}
4551
return x[1], x[2]
4652
end
4753

48-
function _autosplit_variate_after_testvalue(a_test::AbstractVector, x::AbstractVector)
49-
n, m = length(eachindex(a_test)), length(eachindex(x))
54+
function _split_variate_byvalue(test_primary::AbstractVector, x::AbstractVector)
55+
n, m = length(eachindex(test_primary)), length(eachindex(x))
5056
# TODO: Use getindex or view?
5157
return x[begin:n], x[begin+n:m]
5258
end
5359

54-
function _autosplit_variate_after_testvalue(::Tuple{N}, x::Tuple{M}) where {N,M}
60+
function _split_variate_byvalue(::Tuple{N}, x::Tuple{M}) where {N,M}
5561
return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M))
5662
end
5763

58-
@generated function _autosplit_variate_after_testvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names}
64+
@generated function _split_variate_byvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names}
5965
# TODO: implement
6066
@assert false
6167
end
@@ -147,7 +153,7 @@ end
147153

148154
function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x)
149155
dof_μ = getdof(μ)
150-
x_μ, x_rest = _split_variate_after(flatten_mode, μ, x)
156+
x_μ, x_rest = _split_variate(flatten_mode, μ, x)
151157
y = transport_to(ν_inner^dof_μ, μ, x_μ)
152158
return y, x_rest
153159
end

0 commit comments

Comments
 (0)