Commit 70a1117

KeywordCalls compatibility (#93)
> # Changes
>
> **Dependencies**
>
> * Now using KeywordCalls.jl
> * Early stages of a trait interface for `IsPrimitive` and `IsRepresentative`, using SimpleTraits.jl (see `traits.jl`)
> * New Distributions.jl version (no real changes here from that)
> * New TransformVariables has breaking changes; see `transforms.jl`
>
> **Renaming**
>
> * `@measure` is too vague; changed to `@parameterized`
> * `ElementwiseProductMeasure` is a bit much; it's now `PointwiseProductMeasure`
> * `Likelihood` is now `LogLikelihood`, and it's callable as a function
> * `src/probability` is now `src/parameterized`
1 parent 69b7fc5 commit 70a1117
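
A minimal sketch of what these changes look like at the surface (values are illustrative; the constructor and `logdensity` behaviour are assumed from the diffs below):

```julia
using MeasureTheory

# Keyword construction is now routed through KeywordCalls.jl,
# so parameter order shouldn't matter:
d = Normal(σ = 2.0, μ = 0.0)

# Log-density of the measure at a point
logdensity(d, 0.5)
```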

44 files changed: +561 -519 lines

Project.toml

+7-3
@@ -1,7 +1,7 @@
 name = "MeasureTheory"
 uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
 authors = ["Chad Scherrer <[email protected]> and contributors"]
-version = "0.6.2"
+version = "0.7.0"
 
 [deps]
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -13,6 +13,7 @@ DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38"
 InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
 IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
+KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -22,6 +23,7 @@ NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
 NestedTuples = "a734d2a7-8d68-409b-9419-626914d4061d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
+SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -33,23 +35,25 @@ Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
 Accessors = "0.1"
 ArrayInterface = "2,3"
 ConcreteStructs = "0.2"
-Distributions = "0.23, 0.24"
+Distributions = "0.23, 0.24, 0.25"
 DistributionsAD = "0.6"
 DynamicIterators = "0.4.2"
 InfiniteArrays = "0.7, 0.8, 0.9, 0.10"
 IntervalSets = "0.5"
 IterTools = "1"
+KeywordCalls = "0.1.5"
 MLStyle = "0.4"
 MacroTools = "0.5"
 MappedArrays = "0.3, 0.4"
 MonteCarloMeasurements = "0.10"
 NamedTupleTools = "0.13"
 NestedTuples = "0.3"
 RandomNumbers = "1"
+SimpleTraits = "0.9"
 SpecialFunctions = "0.10, 1"
 StaticArrays = "0.12, 1"
 StatsFuns = "0.9"
-TransformVariables = "0.3"
+TransformVariables = "0.4"
 Tricks = "0.1"
 Tullio = "0.2"
 julia = "1.5"

docs/src/adding.md

+2-2
@@ -6,10 +6,10 @@ This is by far the most common kind of measure, and is especially useful as a wa
 
 ### Declaring a Parameterized Measure
 
-To start, declare a `@measure`. For example, `Normal` is declared as
+To start, declare a `@parameterized`. For example, `Normal` is declared as
 
 ```julia
-@measure Normal(μ,σ) (1/sqrt2π) * Lebesgue(ℝ)
+@parameterized Normal(μ,σ) (1/sqrt2π) * Lebesgue(ℝ)
 ```
 
 [`ℝ` is typed as `\bbR <TAB>`]
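
For context, a hedged sketch of declaring a new measure with the renamed macro. The measure name `MyNormal` is hypothetical, and the keyword constructor is assumed to come from the KeywordCalls-based machinery this commit introduces:

```julia
using MeasureTheory
using StatsFuns: sqrt2π

# Hypothetical declaration, following the pattern in the docs above
@parameterized MyNormal(μ, σ) (1 / sqrt2π) * Lebesgue(ℝ)

# Parameters can then be supplied by name, in any order
d = MyNormal(σ = 0.5, μ = 1.0)
```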

src/MeasureTheory.jl

+27-19
@@ -18,6 +18,7 @@ export AbstractMeasure
 using InfiniteArrays
 using ConcreteStructs
 using DynamicIterators
+using KeywordCalls
 
 const ∞ = InfiniteArrays.∞
 
@@ -40,42 +41,49 @@ Methods for computing density relative to other measures will be
 """
 function logdensity end
 
-include("paramorder.jl")
 include("exp.jl")
 include("domains.jl")
 include("utils.jl")
+include("traits.jl")
 include("absolutecontinuity.jl")
-include("basemeasures.jl")
 include("parameterized.jl")
 include("macros.jl")
+
+include("primitives/counting.jl")
+include("primitives/lebesgue.jl")
+include("primitives/dirac.jl")
+include("primitives/trivial.jl")
+
 include("combinators/weighted.jl")
 include("combinators/superpose.jl")
 include("combinators/product.jl")
 include("combinators/for.jl")
 include("combinators/power.jl")
 include("combinators/likelihood.jl")
-include("combinators/elementwise.jl")
+include("combinators/pointwise.jl")
 include("combinators/transforms.jl")
 include("combinators/spikemixture.jl")
 include("combinators/chain.jl")
+
 include("distributions.jl")
 include("rand.jl")
-include("probability/dirac.jl")
-include("probability/normal.jl")
-include("probability/studentt.jl")
-include("probability/cauchy.jl")
-include("probability/laplace.jl")
-include("probability/uniform.jl")
-include("probability/beta.jl")
-include("probability/gumbel.jl")
-include("probability/exponential.jl")
-include("probability/mvnormal.jl")
-include("probability/inverse-gamma.jl")
-include("probability/bernoulli.jl")
-include("probability/poisson.jl")
-include("probability/binomial.jl")
-include("probability/LKJL.jl")
-include("probability/negativebinomial.jl")
+
+include("parameterized/normal.jl")
+include("parameterized/studentt.jl")
+include("parameterized/cauchy.jl")
+include("parameterized/laplace.jl")
+include("parameterized/uniform.jl")
+include("parameterized/beta.jl")
+include("parameterized/gumbel.jl")
+include("parameterized/exponential.jl")
+include("parameterized/mvnormal.jl")
+include("parameterized/inverse-gamma.jl")
+include("parameterized/bernoulli.jl")
+include("parameterized/poisson.jl")
+include("parameterized/binomial.jl")
+include("parameterized/LKJL.jl")
+include("parameterized/negativebinomial.jl")
+
 include("density.jl")
 # include("pushforward.jl")
 include("kernel.jl")

src/absolutecontinuity.jl

+13-15
@@ -53,26 +53,24 @@ export representative
 
 """
     representative(μ::AbstractMeasure) -> AbstractMeasure
-"""
-function representative(μ)
-    function f(μ)
-        # Check if we're done
-        isprimitive(μ) && return μ
-
-        ν = basemeasure(μ)
 
-        # TODO: Make sure we don't leave the equivalence class
-        # Make sure not to leave the equivalence class
-        # (ν ≪ μ) || return μ
+We need to be able to compute `μ ≪ ν` for each `μ` and `ν`. To do this directly
+would require a huge number of methods (quadratic in the number of defined
+measures).
 
-        return ν
-    end
+This function is a way around that. When defining a new measure `μ`, you should
+also find some equivalent measure `ρ` that's "as primitive as possible".
 
-    fix(f, μ)
-end
+If possible, `ρ` should be a `PrimitiveMeasure`, or a `Product` of these. If
+not, it should be a transform (`Pushforward` or `Pullback`) of a
+`PrimitiveMeasure` (or `Product` of these).
+"""
+function representative(μ) end
 
-# TODO: ≪ needs more work
 function ≪(μ, ν)
     μ == ν && return true
     representative(μ) ≪ representative(ν) && return true
+    return false
 end
+
+@traitfn representative(μ::M) where {M; IsRepresentative{M}} = μ
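
A hedged sketch of how the reworked `representative` / `≪` interface might be used when adding a measure. The `MyScaledLebesgue` type is purely illustrative, and it assumes both that `≪` is usable as an infix operator from outside the package and that `Lebesgue` resolves to itself under `representative` via the `IsRepresentative` trait in `traits.jl`:

```julia
using MeasureTheory

# Illustrative measure type, not part of this commit
struct MyScaledLebesgue <: AbstractMeasure
    scale::Float64
end

# Point the new measure at an equivalent measure that is "as primitive as possible"
MeasureTheory.representative(::MyScaledLebesgue) = Lebesgue(ℝ)

# ≪ can then be decided by comparing representatives
MyScaledLebesgue(2.0) ≪ Lebesgue(ℝ)    # expected to be true under the assumptions above
```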

src/basemeasures.jl

-35
This file was deleted.

src/combinators/chain.jl

-56
@@ -94,59 +94,3 @@ function dyniterate(fr::DynamicFor, state)
     u, state = ϕ
     fr.f(u), state
 end
-
-
-# using Soss
-
-# hmm = @model begin
-#     ε ~ Exponential()    # transition
-#     σ ~ Exponential()    # Observation noise
-#     x ~ Chain(Normal()) do xj
-#         Normal(xj, ε)
-#     end
-
-#     y ~ For(x) do xj
-#         Normal(xj, σ)
-#     end
-# end
-
-# using Soss
-
-# mbind = @model μ,κ begin
-#     x ~ μ
-#     y ~ κ(x)
-#     return y
-# end
-
-# ⋅(μ,κ) = mbind(μ,κ)
-
-# d = Cauchy() ⋅ (x -> Normal(μ=x)) ⋅ (x -> Normal(μ=x)) ⋅ (x -> Normal(μ=x))
-
-# rand(d)
-# t = xform(d)
-# t(randn(4))
-# simulate(d)
-
-# # julia> d = Cauchy() ⋅ (x -> Normal(μ=x)) ⋅ (x -> Normal(μ=x)) ⋅ (x -> Normal(μ=x))
-# # ConditionalModel given
-# #     arguments    (:μ, :κ)
-# #     observations ()
-# # @model (μ, κ) begin
-# #     x ~ μ
-# #     y ~ κ(x)
-# #     return y
-# # end
-
-
-
-# # julia> rand(d)
-# # -3.0414465047589037
-
-# # julia> t = xform(d)
-# # TransformVariables.TransformTuple{NamedTuple{(:x, :y), Tuple{TransformVariables.TransformTuple{NamedTuple{(:x, :y), Tuple{TransformVariables.TransformTuple{NamedTuple{(:x, :y), Tuple{TransformVariables.Identity, TransformVariables.Identity}}}, TransformVariables.Identity}}}, TransformVariables.Identity}}}((x = TransformVariables.TransformTuple{NamedTuple{(:x, :y), Tuple{TransformVariables.TransformTuple{NamedTuple{(:x, :y), Tuple{TransformVariables.Identity, TransformVariables.Identity}}}, TransformVariables.Identity}}}((x = TransformVariables.TransformTuple{NamedTuple{(:x, :y), Tuple{TransformVariables.Identity, TransformVariables.Identity}}}((x = asℝ, y = asℝ), 2), y = asℝ), 3), y = asℝ), 4)
-
-# # julia> t(randn(4))
-# # (x = (x = (x = -0.24259286698966315, y = 0.278190893626807), y = -1.361907586870645), y = 0.05914265096096323)
-
-# # julia> simulate(d)
-# # (value = 5.928939554009484, trace = (x = (value = 5.307006358072237, trace = (x = (value = 3.2023770380851797, trace = (x = 3.677550124255551, y = 3.2023770380851797)), y = 5.307006358072237)), y = 5.928939554009484))

src/combinators/elementwise.jl

-63
This file was deleted.

src/combinators/for.jl

+1-1
@@ -152,7 +152,7 @@ function Base.rand(rng::AbstractRNG, T::Type, d::ForGenerator)
     Base.Generator(r ∘ d.data.f, d.data.iter)
 end
 
-function MeasureTheory.logdensity(d::ForGenerator, x)
+function logdensity(d::ForGenerator, x)
     sum((logdensity(dj, xj) for (dj, xj) in zip(d.data, x)))
 end
 
src/combinators/likelihood.jl

+7-7
@@ -1,15 +1,15 @@
-export Likelihood
+export LogLikelihood
 
-@concrete terse struct Likelihood{T,X}
+@concrete terse struct LogLikelihood{T,X}
     x::X
 end
 
-Likelihood(T::Type, x::X) where {X} = Likelihood{T,X}(x)
+LogLikelihood(T::Type, x::X) where {X} = LogLikelihood{T,X}(x)
 
-Likelihood(μ::T, x::X) where {X, T<:AbstractMeasure} = Likelihood{T,X}(x)
+LogLikelihood(μ::T, x::X) where {X, T<:AbstractMeasure} = LogLikelihood{T,X}(x)
 
-logdensity(ℓ::Likelihood, p) = ℓ(p)
+logdensity(ℓ::LogLikelihood, p) = ℓ(p)
 
-(ℓ::Likelihood{T,X})(p) where {T,X} = logdensity(T(p), ℓ.x)
+(ℓ::LogLikelihood{T,X})(p) where {T,X} = logdensity(T(p), ℓ.x)
 
-(ℓ::Likelihood)(;kwargs...) = ℓ((;kwargs...))
+(ℓ::LogLikelihood)(;kwargs...) = ℓ((;kwargs...))
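
And a hedged sketch of the renamed, callable likelihood in use (it assumes `Normal` accepts a `NamedTuple` of parameters, which is what the `T(p)` call above requires):

```julia
using MeasureTheory

# Fix the observed data x = 0.5; the parameters remain free
ℓ = LogLikelihood(Normal, 0.5)

# Callable as a function of the parameters, via the keyword method above
ℓ(μ = 0.0, σ = 1.0)    # same as logdensity(Normal((μ = 0.0, σ = 1.0)), 0.5)
```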
