Skip to content

Commit 648376d

Browse files
authored
Add variate transport
1 parent bf7eae6 commit 648376d

20 files changed

+742
-40
lines changed

Project.toml

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.10.0"
4+
version = "0.11.0"
55

66
[deps]
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
79
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
810
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
911
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
1012
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1113
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
14+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
1215
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1316
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1417
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
@@ -24,11 +27,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2427
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
2528

2629
[compat]
30+
ChainRulesCore = "1"
31+
ChangesOfVariables = "0.1.3"
2732
Compat = "3.35, 4"
2833
ConstructionBase = "1.3"
2934
DensityInterface = "0.4"
3035
FillArrays = "0.12, 0.13"
3136
IfElse = "0.1"
37+
InverseFunctions = "0.1.7"
3238
IrrationalConstants = "0.1"
3339
LogExpFunctions = "0.3"
3440
LogarithmicNumbers = "1"
@@ -42,6 +48,7 @@ julia = "1.3"
4248

4349
[extras]
4450
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
51+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
4552

4653
[targets]
47-
test = ["Aqua"]
54+
test = ["Aqua", "ChainRulesTestUtils"]

src/MeasureBase.jl

+14-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module MeasureBase
22

3+
using Base: @propagate_inbounds
4+
35
using Random
46
import Random: rand!
57
import Random: gentype
@@ -11,13 +13,17 @@ import DensityInterface: densityof
1113
import DensityInterface: DensityKind
1214
using DensityInterface
1315

16+
using InverseFunctions
17+
using ChangesOfVariables
18+
1419
import Base.iterate
1520
import ConstructionBase
1621
using ConstructionBase: constructorof
1722

1823
using PrettyPrinting
1924
const Pretty = PrettyPrinting
2025

26+
using ChainRulesCore
2127
using FillArrays
2228
using Static
2329

@@ -32,20 +38,11 @@ export logdensity_def
3238
export basemeasure
3339
export basekernel
3440
export productmeasure
35-
36-
"""
37-
inssupport(m, x)
38-
insupport(m)
39-
40-
`insupport(m,x)` computes whether `x` is in the support of `m`.
41-
42-
`insupport(m)` returns a function, and satisfies
43-
44-
insupport(m)(x) == insupport(m, x)
45-
"""
46-
function insupport end
47-
4841
export insupport
42+
export getdof
43+
export transport_to
44+
45+
include("insupport.jl")
4946

5047
abstract type AbstractMeasure end
5148

@@ -63,7 +60,7 @@ gentype(μ::AbstractMeasure) = typeof(testvalue(μ))
6360
# gentype(μ::AbstractMeasure) = gentype(basemeasure(μ))
6461

6562
using NaNMath
66-
using LogExpFunctions: logsumexp
63+
using LogExpFunctions: logsumexp, logistic, logit
6764

6865
@deprecate instance_type(x) Core.Typeof(x) false
6966

@@ -94,6 +91,8 @@ using Compat
9491

9592
using IrrationalConstants
9693

94+
include("getdof.jl")
95+
include("transport.jl")
9796
include("schema.jl")
9897
include("splat.jl")
9998
include("proxies.jl")
@@ -125,9 +124,9 @@ include("combinators/powerweighted.jl")
125124
include("combinators/conditional.jl")
126125

127126
include("standard/stdmeasure.jl")
128-
include("standard/stdnormal.jl")
129127
include("standard/stduniform.jl")
130128
include("standard/stdexponential.jl")
129+
include("standard/stdlogistic.jl")
131130

132131
include("rand.jl")
133132

src/combinators/power.jl

+28
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ end
8585
end
8686
end
8787

88+
@inline function logdensity_def(
89+
d::PowerMeasure{M,NTuple{N, Base.OneTo{StaticInt{0}}}},
90+
x,
91+
) where {M,N}
92+
static(0.0)
93+
end
94+
8895
@inline function insupport::PowerMeasure, x)
8996
p = μ.parent
9097
all(x) do xj
@@ -100,3 +107,24 @@ end
100107
dynamic(insupport(p, xj))
101108
end
102109
end
110+
111+
112+
@inline getdof::PowerMeasure) = getdof.parent) * prod(map(length, μ.axes))
113+
114+
@inline getdof(::PowerMeasure{<:Any, NTuple{N,Base.OneTo{StaticInt{0}}}}) where N = static(0)
115+
116+
117+
@propagate_inbounds function checked_var::PowerMeasure, x::AbstractArray{<:Any})
118+
@boundscheck begin
119+
sz_μ = map(length, μ.axes)
120+
sz_x = size(x)
121+
if sz_μ != sz_x
122+
throw(ArgumentError("Size of variate doesn't match size of power measure"))
123+
end
124+
end
125+
return x
126+
end
127+
128+
function checked_var::PowerMeasure, x::Any)
129+
throw(ArgumentError("Size of variate doesn't match size of power measure"))
130+
end

src/combinators/transformedmeasure.jl

+91
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,94 @@ function params(::AbstractTransformedMeasure) end
1313
function paramnames(::AbstractTransformedMeasure) end
1414

1515
function parent(::AbstractTransformedMeasure) end
16+
17+
18+
export PushforwardMeasure
19+
20+
"""
21+
struct PushforwardMeasure{FF,IF,MU,VC<:TransformVolCorr} <: AbstractPushforward
22+
f :: FF
23+
inv_f :: IF
24+
origin :: MU
25+
volcorr :: VC
26+
end
27+
"""
28+
struct PushforwardMeasure{FF,IF,M,VC<:TransformVolCorr} <: AbstractPushforward
29+
f::FF
30+
inv_f::IF
31+
origin::M
32+
volcorr::VC
33+
end
34+
35+
gettransform::PushforwardMeasure) = ν.f
36+
parent::PushforwardMeasure) = ν.origin
37+
38+
39+
function Pretty.tile::PushforwardMeasure)
40+
Pretty.list_layout(Pretty.tile.([ν.f, ν.inv_f, ν.origin]); prefix = :PushforwardMeasure)
41+
end
42+
43+
44+
@inline function logdensity_def::PushforwardMeasure{FF,IF,M,<:WithVolCorr}, y) where {FF,IF,M}
45+
x_orig, inv_ladj = with_logabsdet_jacobian.inv_f, y)
46+
logd_orig = logdensity_def.origin, x_orig)
47+
logd = float(logd_orig + inv_ladj)
48+
neginf = oftype(logd, -Inf)
49+
return ifelse(
50+
# Zero density wins against infinite volume:
51+
(isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) ||
52+
# Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
53+
# Return constant -Inf to prevent problems with ForwardDiff:
54+
(isfinite(logd_orig) && (inv_ladj == -Inf)),
55+
neginf,
56+
logd
57+
)
58+
end
59+
60+
@inline function logdensity_def::PushforwardMeasure{FF,IF,M,<:NoVolCorr}, y) where {FF,IF,M}
61+
x_orig = to_origin(ν, y)
62+
return logdensity_def.origin, x_orig)
63+
end
64+
65+
66+
insupport::PushforwardMeasure, y) = insupport(transport_origin(ν), to_origin(ν, y))
67+
68+
testvalue::PushforwardMeasure) = from_origin(ν, testvalue(transport_origin(ν)))
69+
70+
@inline function basemeasure::PushforwardMeasure)
71+
PushforwardMeasure.f, ν.inv_f, basemeasure(transport_origin(ν)), NoVolCorr())
72+
end
73+
74+
75+
_pushfwd_dof(::Type{MU}, ::Type, dof) where MU = NoDOF{MU}()
76+
_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where MU = dof
77+
78+
# Assume that DOF are preserved if with_logabsdet_jacobian is functional:
79+
@inline function getdof::MU) where {MU<:PushforwardMeasure}
80+
T = Core.Compiler.return_type(testvalue, Tuple{typeof.origin)})
81+
R = Core.Compiler.return_type(with_logabsdet_jacobian, Tuple{typeof.f), T})
82+
_pushfwd_dof(MU, R, getdof.origin))
83+
end
84+
85+
# Bypass `checked_var`, would require potentially costly transformation:
86+
@inline checked_var(::PushforwardMeasure, x) = x
87+
88+
89+
@inline transport_origin::PushforwardMeasure) = ν.origin
90+
@inline from_origin::PushforwardMeasure, x) = ν.f(x)
91+
@inline to_origin::PushforwardMeasure, y) = ν.inv_f(y)
92+
93+
function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where T
94+
return from_origin(ν, rand(rng, T, transport_origin(ν)))
95+
end
96+
97+
98+
export pushfwd
99+
100+
"""
101+
pushfwd(f, μ, volcorr = WithVolCorr())
102+
103+
Return the [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure)
104+
from `μ` the [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.
105+
"""
106+
pushfwd(f, μ, volcorr = WithVolCorr()) = PushforwardMeasure(f, inverse(f), μ, volcorr)

src/combinators/weighted.jl

+4
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,7 @@ Base.:*(m::AbstractMeasure, k::Real) = k * m
4848
gentype::WeightedMeasure) = gentype.base)
4949

5050
insupport::WeightedMeasure, x) = insupport.base, x)
51+
52+
transport_origin::WeightedMeasure) = ν.base
53+
to_origin(::WeightedMeasure, y) = y
54+
from_origin(::WeightedMeasure, x) = x

src/getdof.jl

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""
2+
MeasureBase.NoDOF{MU}
3+
4+
Indicates that there is no way to compute degrees of freedom of a measure
5+
of type `MU` with the given information, e.g. because the DOF are not
6+
a global property of the measure.
7+
"""
8+
struct NoDOF{MU} end
9+
10+
11+
"""
12+
getdof(μ)
13+
14+
Returns the effective number of degrees of freedom of variates of
15+
measure `μ`.
16+
17+
The effective NDOF my differ from the length of the variates. For example,
18+
the effective NDOF for a Dirichlet distribution with variates of length `n`
19+
is `n - 1`.
20+
21+
Also see [`check_dof`](@ref).
22+
"""
23+
function getdof end
24+
25+
# Prevent infinite recursion:
26+
@inline _default_getdof(::Type{MU}, ::MU) where MU = NoDOF{MU}
27+
@inline _default_getdof(::Type{MU}, mu_base) where MU = getdof(mu_base)
28+
29+
@inline getdof::MU) where MU = _default_getdof(MU, basemeasure(μ))
30+
31+
32+
"""
33+
MeasureBase.check_dof(ν, μ)::Nothing
34+
35+
Check if `ν` and `μ` have the same effective number of degrees of freedom
36+
according to [`MeasureBase.getdof`](@ref).
37+
"""
38+
function check_dof end
39+
40+
function check_dof(ν, μ)
41+
n_ν = getdof(ν)
42+
n_μ = getdof(μ)
43+
if n_ν != n_μ
44+
throw(ArgumentError("Measure ν of type $(nameof(typeof(ν))) has $(n_ν) DOF but μ of type $(nameof(typeof(μ))) has $(n_μ) DOF"))
45+
end
46+
return nothing
47+
end
48+
49+
_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
50+
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback
51+
52+
53+
"""
54+
MeasureBase.NoVarCheck{MU,T}
55+
56+
Indicates that there is no way to check of a values of type `T` are
57+
variate of measures of type `MU`.
58+
"""
59+
struct NoVarCheck{MU,T} end
60+
61+
62+
"""
63+
MeasureBase.checked_var(μ::MU, x::T)::T
64+
65+
Return `x` if `x` is a valid variate of `μ`, throw an `ArgumentError` if not,
66+
return `NoVarCheck{MU,T}()` if not check can be performed.
67+
"""
68+
function checked_var end
69+
70+
# Prevent infinite recursion:
71+
@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoVarCheck{MU,T}
72+
@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_var(mu_base, x)
73+
74+
@propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x)
75+
76+
_checked_var_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
77+
ChainRulesCore.rrule(::typeof(checked_var), ν, x) = checked_var(ν, x), _checked_var_pullback

src/insupport.jl

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""
2+
inssupport(m, x)
3+
insupport(m)
4+
5+
`insupport(m,x)` computes whether `x` is in the support of `m`.
6+
7+
`insupport(m)` returns a function, and satisfies
8+
9+
insupport(m)(x) == insupport(m, x)
10+
"""
11+
function insupport end
12+
13+
14+
"""
15+
MeasureBase.require_insupport(μ, x)::Nothing
16+
17+
Checks if `x` is in the support of distribution/measure `μ`, throws an
18+
`ArgumentError` if not.
19+
"""
20+
function require_insupport end
21+
22+
_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent()
23+
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
24+
return require_insupport(μ, x), _require_insupport_pullback
25+
end
26+
27+
function require_insupport(μ, x)
28+
if !insupport(μ, x)
29+
throw(ArgumentError("x is not within the support of μ"))
30+
end
31+
return nothing
32+
end

0 commit comments

Comments
 (0)