@@ -13,136 +13,179 @@ abstract type ContinuousMultivariateConjugatePostDistribution <: MultivariateCon
1313# Gaussian with Normal Inverse Wishart Prior
1414mutable struct WishartGaussian <: ContinuousMultivariateConjugatePostDistribution
1515
16- D:: Int
16+ D:: Int
1717
18- # sufficient statistics
19- n:: Int
20- sums:: Vector{Float64}
21- ssums:: Array{Float64}
18+ # sufficient statistics
19+ n:: Int
20+ sums:: Vector{Float64}
21+ ssums:: Array{Float64}
2222
23- # base model parameters
24- μ0:: Vector{Float64}
25- κ0:: Float64
26- ν0:: Float64
27- Σ0:: Array{Float64}
23+ # base model parameters
24+ μ0:: Vector{Float64}
25+ κ0:: Float64
26+ ν0:: Float64
27+ Σ0:: Array{Float64}
2828
29- function WishartGaussian (μ0:: Vector{Float64} , κ0:: Float64 ,
30- ν0:: Float64 , Σ0:: Array{Float64} )
29+ end
30+
31+ """
32+ WishartGaussian(μ0, κ0, ν0, Σ0)
33+
34+ ## Gaussian-inverse-Wishart distribution
35+ A Gaussian-inverse-Wishart distribution is the conjugate prior of a multivariate normal distribution with unknown mean and covariance matrix.
36+
37+ ## Parameters
38+ * `μ0, Dx1`: location
39+ * `κ0 > 0`: number of pseudo-observations
40+ * `ν0 > D-1`: degrees of freedom
41+ * `Σ0 > 0, DxD`: scale matrix
3142
32- d = length (μ0)
33- new (d, 0 , zeros (d), zeros (d, d), μ0, κ0, ν0, Σ0)
34- end
43+ ## Example
44+ ```julia-repl
45+ julia> (N, D) = size(X)
46+ julia> μ0 = mean(X, dims = 1)
47+ julia> d = WishartGaussian(μ0, 1.0, 2*D, cov(x))
48+ ```
3549
50+ """
51+ function WishartGaussian (μ0:: Vector{Float64} , κ0:: Float64 ,
52+ ν0:: Float64 , Σ0:: Array{Float64} )
53+
54+ d = length (μ0)
55+ (D1, D2) = size (Σ0)
56+ @assert D1 == D2
57+ @assert D1 == d
58+
59+ WishartGaussian (d, 0 , zeros (d), zeros (d, d), μ0, κ0, ν0, Σ0)
3660end
3761
62+
3863# Normal with Gamma prior
3964mutable struct GammaNormal <: ContinuousUnivariateConjugatePostDistribution
4065
41- # sufficient statistics
42- n:: Int
43- sums:: Float64
44- ssums:: Float64
66+ # sufficient statistics
67+ n:: Int
68+ sums:: Float64
69+ ssums:: Float64
4570
46- # model parameters
47- μ0:: Float64
48- λ0:: Float64
49- α0:: Float64
50- β0:: Float64
71+ # model parameters
72+ μ0:: Float64
73+ λ0:: Float64
74+ α0:: Float64
75+ β0:: Float64
5176
52- function GammaNormal (;μ0 = 0.0 , λ0 = 1.0 , α0 = 1.0 , β0 = 1.0 )
53- new (0 , 0.0 , 0.0 , μ0, λ0, α0, β0)
54- end
77+ end
5578
79+ """
80+ GammaNormal(; μ0 = 0.0, λ0 = 1.0, α0 = 1.0, β0 = 1.0)
81+
82+ ## Normal-Gamma distribution
83+ A Normal-Gamma distribution is the conjugate prior of a Normal distribution
84+ with unknown mean and precision.
85+
86+ ## Paramters
87+ * `μ0`: location
88+ * `λ0 > 0`: number of pseudo-observations
89+ * `α0 > 0`
90+ * `β0 > 0`
91+
92+ Example:
93+ ```julia
94+ d = GammaNormal()
95+ ```
96+ """
97+ function GammaNormal (;μ0 = 0.0 , λ0 = 1.0 , α0 = 1.0 , β0 = 1.0 )
98+ GammaNormal (0 , 0.0 , 0.0 , μ0, λ0, α0, β0)
5699end
57100
58101# Normal with Normal prior
59102mutable struct NormalNormal <: ContinuousUnivariateConjugatePostDistribution
60103
61- # sufficient statistics
62- n:: Int
63- sums:: Float64
64- ssums:: Float64
104+ # sufficient statistics
105+ n:: Int
106+ sums:: Float64
107+ ssums:: Float64
65108
66- # model parameters
67- μ0:: Float64
68- σ0:: Float64
109+ # model parameters
110+ μ0:: Float64
111+ σ0:: Float64
69112
70- function NormalNormal (;μ0 = 0.0 , σ0 = 1.0 )
71- new (0 , 0.0 , 0.0 , μ0, σ0)
72- end
113+ function NormalNormal (;μ0 = 0.0 , σ0 = 1.0 )
114+ new (0 , 0.0 , 0.0 , μ0, σ0)
115+ end
73116
74117end
75118
76119# Gaussian with Diagonal Covariance
77120mutable struct GaussianDiagonal{T <: ContinuousUnivariateConjugatePostDistribution } <: ContinuousMultivariateConjugatePostDistribution
78121
79- # sufficient statistics
80- dists:: Vector{T}
122+ # sufficient statistics
123+ dists:: Vector{T}
81124
82- # isn't the default constructor sufficient here?
83- # function GaussianDiagonal(dists::Vector{T})
84- # new(dists)
85- # end
125+ # isn't the default constructor sufficient here?
126+ # function GaussianDiagonal(dists::Vector{T})
127+ # new(dists)
128+ # end
86129
87130end
88131
89132# Multinomial with Dirichlet Prior
90133mutable struct DirichletMultinomial <: DiscreteMultivariateConjugatePostDistribution
91134
92- D:: Int
135+ D:: Int
93136
94- # sufficient statistics
95- n:: Int
96- counts:: SparseMatrixCSC{Int,Int}
137+ # sufficient statistics
138+ n:: Int
139+ counts:: SparseMatrixCSC{Int,Int}
97140
98- # base model parameters
99- α0:: Float64
141+ # base model parameters
142+ α0:: Float64
100143
101- # cache
102- dirty:: Bool
103- Z2:: Float64
104- Z3:: Array{Float64}
144+ # cache
145+ dirty:: Bool
146+ Z2:: Float64
147+ Z3:: Array{Float64}
105148
106- function DirichletMultinomial (D:: Int , α0:: Float64 )
107- new (D, 0 , sparsevec (zeros (D)), α0, true , 0.0 , Array {Float64} (0 ))
108- end
149+ function DirichletMultinomial (D:: Int , α0:: Float64 )
150+ new (D, 0 , sparsevec (zeros (D)), α0, true , 0.0 , Array {Float64} (0 ))
151+ end
109152
110153end
111154
112155# Categorical with Dirichlet Prior
113156mutable struct DirichletCategorical <: DiscreteUnivariateConjugatePostDistribution
114157
115- # sufficient statistics
116- n:: Int
117- counts:: SparseMatrixCSC{Int,Int}
158+ # sufficient statistics
159+ n:: Int
160+ counts:: SparseMatrixCSC{Int,Int}
118161
119- # base model parameters
120- α0:: Float64
162+ # base model parameters
163+ α0:: Float64
121164
122- # cache
123- dirty:: Bool
124- Z2:: Float64
125- Z3:: Array{Float64}
165+ # cache
166+ dirty:: Bool
167+ Z2:: Float64
168+ Z3:: Array{Float64}
126169
127- function DirichletMultinomial (D:: Int , α0:: Float64 )
128- new (D, 0 , sparsevec (zeros (D)), α0, true , 0.0 , Array {Float64} (0 ))
129- end
170+ function DirichletMultinomial (D:: Int , α0:: Float64 )
171+ new (D, 0 , sparsevec (zeros (D)), α0, true , 0.0 , Array {Float64} (0 ))
172+ end
130173
131174end
132175
133176# Bernoulli with Beta Prior
134177mutable struct BetaBernoulli <: DiscreteUnivariateConjugatePostDistribution
135178
136- # sufficient statistics
137- successes:: Int
138- n:: Int
179+ # sufficient statistics
180+ successes:: Int
181+ n:: Int
139182
140- # beta distribution parameters
141- α0:: Float64
142- β0:: Float64
183+ # beta distribution parameters
184+ α0:: Float64
185+ β0:: Float64
143186
144- function BetaBernoulli (;α0 = 1.0 , β0 = 1.0 )
145- new (0 , 0 , α0, β0)
146- end
187+ function BetaBernoulli (;α0 = 1.0 , β0 = 1.0 )
188+ new (0 , 0 , α0, β0)
189+ end
147190
148191end
0 commit comments