Skip to content

Commit ff7b123

Browse files
committed
added GammaNormal tests and fixed some deprecations
1 parent 0358537 commit ff7b123

File tree

5 files changed

+133
-93
lines changed

5 files changed

+133
-93
lines changed

src/distributions.jl

Lines changed: 120 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -13,136 +13,179 @@ abstract type ContinuousMultivariateConjugatePostDistribution <: MultivariateCon
1313
# Gaussian with Normal Inverse Wishart Prior
1414
mutable struct WishartGaussian <: ContinuousMultivariateConjugatePostDistribution
1515

16-
D::Int
16+
D::Int
1717

18-
# sufficient statistics
19-
n::Int
20-
sums::Vector{Float64}
21-
ssums::Array{Float64}
18+
# sufficient statistics
19+
n::Int
20+
sums::Vector{Float64}
21+
ssums::Array{Float64}
2222

23-
# base model parameters
24-
μ0::Vector{Float64}
25-
κ0::Float64
26-
ν0::Float64
27-
Σ0::Array{Float64}
23+
# base model parameters
24+
μ0::Vector{Float64}
25+
κ0::Float64
26+
ν0::Float64
27+
Σ0::Array{Float64}
2828

29-
function WishartGaussian(μ0::Vector{Float64}, κ0::Float64,
30-
ν0::Float64, Σ0::Array{Float64})
29+
end
30+
31+
"""
32+
WishartGaussian(μ0, κ0, ν0, Σ0)
33+
34+
## Gaussian-inverse-Wishart distribution
35+
A Gaussian-inverse-Wishart distribution is the conjugate prior of a multivariate normal distribution with unknown mean and covariance matrix.
36+
37+
## Parameters
38+
* `μ0, Dx1`: location
39+
* `κ0 > 0`: number of pseudo-observations
40+
* `ν0 > D-1`: degrees of freedom
41+
* `Σ0 > 0, DxD`: scale matrix
3142
32-
d = length(μ0)
33-
new(d, 0, zeros(d), zeros(d, d), μ0, κ0, ν0, Σ0)
34-
end
43+
## Example
44+
```julia-repl
45+
julia> (N, D) = size(X)
46+
julia> μ0 = vec(mean(X, dims = 1))
47+
julia> d = WishartGaussian(μ0, 1.0, 2*D, cov(X))
48+
```
3549
50+
"""
51+
function WishartGaussian(μ0::Vector{Float64}, κ0::Float64,
52+
ν0::Float64, Σ0::Array{Float64})
53+
54+
d = length(μ0)
55+
(D1, D2) = size(Σ0)
56+
@assert D1 == D2
57+
@assert D1 == d
58+
59+
WishartGaussian(d, 0, zeros(d), zeros(d, d), μ0, κ0, ν0, Σ0)
3660
end
3761

62+
3863
# Normal with Gamma prior
3964
mutable struct GammaNormal <: ContinuousUnivariateConjugatePostDistribution
4065

41-
# sufficient statistics
42-
n::Int
43-
sums::Float64
44-
ssums::Float64
66+
# sufficient statistics
67+
n::Int
68+
sums::Float64
69+
ssums::Float64
4570

46-
# model parameters
47-
μ0::Float64
48-
λ0::Float64
49-
α0::Float64
50-
β0::Float64
71+
# model parameters
72+
μ0::Float64
73+
λ0::Float64
74+
α0::Float64
75+
β0::Float64
5176

52-
function GammaNormal(;μ0 = 0.0, λ0 = 1.0, α0 = 1.0, β0 = 1.0)
53-
new(0, 0.0, 0.0, μ0, λ0, α0, β0)
54-
end
77+
end
5578

79+
"""
80+
GammaNormal(; μ0 = 0.0, λ0 = 1.0, α0 = 1.0, β0 = 1.0)
81+
82+
## Normal-Gamma distribution
83+
A Normal-Gamma distribution is the conjugate prior of a Normal distribution
84+
with unknown mean and precision.
85+
86+
## Parameters
87+
* `μ0`: location
88+
* `λ0 > 0`: number of pseudo-observations
89+
* `α0 > 0`
90+
* `β0 > 0`
91+
92+
## Example
93+
```julia
94+
d = GammaNormal()
95+
```
96+
"""
97+
function GammaNormal(;μ0 = 0.0, λ0 = 1.0, α0 = 1.0, β0 = 1.0)
98+
GammaNormal(0, 0.0, 0.0, μ0, λ0, α0, β0)
5699
end
57100

58101
# Normal with Normal prior
59102
mutable struct NormalNormal <: ContinuousUnivariateConjugatePostDistribution
60103

61-
# sufficient statistics
62-
n::Int
63-
sums::Float64
64-
ssums::Float64
104+
# sufficient statistics
105+
n::Int
106+
sums::Float64
107+
ssums::Float64
65108

66-
# model parameters
67-
μ0::Float64
68-
σ0::Float64
109+
# model parameters
110+
μ0::Float64
111+
σ0::Float64
69112

70-
function NormalNormal(;μ0 = 0.0, σ0 = 1.0)
71-
new(0, 0.0, 0.0, μ0, σ0)
72-
end
113+
function NormalNormal(;μ0 = 0.0, σ0 = 1.0)
114+
new(0, 0.0, 0.0, μ0, σ0)
115+
end
73116

74117
end
75118

76119
# Gaussian with Diagonal Covariance
77120
mutable struct GaussianDiagonal{T <: ContinuousUnivariateConjugatePostDistribution} <: ContinuousMultivariateConjugatePostDistribution
78121

79-
# sufficient statistics
80-
dists::Vector{T}
122+
# sufficient statistics
123+
dists::Vector{T}
81124

82-
# isn't the default constructor sufficient here?
83-
#function GaussianDiagonal(dists::Vector{T})
84-
# new(dists)
85-
#end
125+
# isn't the default constructor sufficient here?
126+
#function GaussianDiagonal(dists::Vector{T})
127+
# new(dists)
128+
#end
86129

87130
end
88131

89132
# Multinomial with Dirichlet Prior
90133
mutable struct DirichletMultinomial <: DiscreteMultivariateConjugatePostDistribution
91134

92-
D::Int
135+
D::Int
93136

94-
# sufficient statistics
95-
n::Int
96-
counts::SparseMatrixCSC{Int,Int}
137+
# sufficient statistics
138+
n::Int
139+
counts::SparseMatrixCSC{Int,Int}
97140

98-
# base model parameters
99-
α0::Float64
141+
# base model parameters
142+
α0::Float64
100143

101-
# cache
102-
dirty::Bool
103-
Z2::Float64
104-
Z3::Array{Float64}
144+
# cache
145+
dirty::Bool
146+
Z2::Float64
147+
Z3::Array{Float64}
105148

106-
function DirichletMultinomial(D::Int, α0::Float64)
107-
new(D, 0, sparsevec(zeros(D)), α0, true, 0.0, Array{Float64}(0))
108-
end
149+
function DirichletMultinomial(D::Int, α0::Float64)
150+
new(D, 0, sparsevec(zeros(D)), α0, true, 0.0, Array{Float64}(0))
151+
end
109152

110153
end
111154

112155
# Categorical with Dirichlet Prior
113156
mutable struct DirichletCategorical <: DiscreteUnivariateConjugatePostDistribution
114157

115-
# sufficient statistics
116-
n::Int
117-
counts::SparseMatrixCSC{Int,Int}
158+
# sufficient statistics
159+
n::Int
160+
counts::SparseMatrixCSC{Int,Int}
118161

119-
# base model parameters
120-
α0::Float64
162+
# base model parameters
163+
α0::Float64
121164

122-
# cache
123-
dirty::Bool
124-
Z2::Float64
125-
Z3::Array{Float64}
165+
# cache
166+
dirty::Bool
167+
Z2::Float64
168+
Z3::Array{Float64}
126169

127-
function DirichletMultinomial(D::Int, α0::Float64)
128-
new(D, 0, sparsevec(zeros(D)), α0, true, 0.0, Array{Float64}(0))
129-
end
170+
function DirichletMultinomial(D::Int, α0::Float64)
171+
new(D, 0, sparsevec(zeros(D)), α0, true, 0.0, Array{Float64}(0))
172+
end
130173

131174
end
132175

133176
# Bernoulli with Beta Prior
134177
mutable struct BetaBernoulli <: DiscreteUnivariateConjugatePostDistribution
135178

136-
# sufficient statistics
137-
successes::Int
138-
n::Int
179+
# sufficient statistics
180+
successes::Int
181+
n::Int
139182

140-
# beta distribution parameters
141-
α0::Float64
142-
β0::Float64
183+
# beta distribution parameters
184+
α0::Float64
185+
β0::Float64
143186

144-
function BetaBernoulli(;α0 = 1.0, β0 = 1.0)
145-
new(0, 0, α0, β0)
146-
end
187+
function BetaBernoulli(;α0 = 1.0, β0 = 1.0)
188+
new(0, 0, α0, β0)
189+
end
147190

148191
end

src/dpmm.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ function gibbs!(B::DPMBuffer)
211211
if k > length(B.G)
212212
# add new cluster
213213
Gk = add(B.G0, x)
214-
B.G = cat(1, B.G, Gk)
215-
B.C = cat(1, B.C, 0)
214+
B.G = vcat(B.G, Gk)
215+
B.C = vcat(B.C, 0)
216216
else
217217
# add to cluster
218218
add!(B.G[k], x)
@@ -234,7 +234,7 @@ end
234234
"Compute Energy of model for given data"
235235
function updateenergy!(B::DPMData, X::AbstractArray)
236236

237-
E = 0.00001
237+
E = eps() * 10^10
238238

239239
for xi in 1:size(X, 1)
240240

src/hdp.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,12 @@ function gibbs!(B::HDPBuffer)
203203
if c > B.K
204204
# create new cluster
205205
B.K += 1
206-
B.G = cat(1, B.G, deepcopy(B.G0))
206+
B.G = vcat(B.G, deepcopy(B.G0))
207207
b = rand(Dirichlet([1, B.γ]))
208208
b = b * B.β[end]
209-
B.β = cat(1, B.β, 1)
209+
B.β = vcat(B.β, 1)
210210
B.β[end-1:end] = b
211-
B.C = cat(1, B.C, zeros(Int, 1, B.N0))
211+
B.C = vcat(B.C, zeros(Int, 1, B.N0))
212212
prob = zeros(B.K + 1) * -Inf
213213
end
214214

@@ -221,7 +221,7 @@ function gibbs!(B::HDPBuffer)
221221

222222
# sample number of tables
223223
kk = maximum([0, B.K - length(B.totalnt)])
224-
B.totalnt = cat(2, B.totalnt - sum(B.classnt, 1), zeros(Int, 1, kk))
224+
B.totalnt = hcat(B.totalnt - sum(B.classnt, 1), zeros(Int, 1, kk))
225225
B.classnt = randnumtable(B.α .* B.β[:,ones(Int, B.N0)]', B.C')
226226
B.totalnt = B.totalnt + sum(B.classnt, 1)
227227

test/distributionTests.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ using LinearAlgebra
6565
N = length(x)
6666

6767
# distribution
68-
d = GammaNormal(μ0, λ0, α0, β0)
68+
d = GammaNormal(μ0 = μ0, λ0 = λ0, α0 = α0, β0 = β0)
6969

7070
# test prior
7171
(μ, λ, α, β) = BayesianNonparametrics.posteriorParameters(d)
@@ -79,13 +79,14 @@ using LinearAlgebra
7979
BayesianNonparametrics.add!(d, x)
8080

8181
# test posterior paramters
82+
# see: https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf page 8.
8283
(μ, λ, α, β) = BayesianNonparametrics.posteriorParameters(d)
8384

84-
#@test μ ==
8585
@test λ == λ0 + N
86-
@test α = α0 + (N / 2)
87-
#@test β ==
88-
86+
@test α == α0 + (N / 2)
87+
@test μ == (λ0 * μ0 + N * mean(x)) / (λ0 + N)
88+
@test β == β0 + 1/2 * sum( (x .- mean(x)).^2 ) + ( λ0 * N * (mean(x) - μ0)^2 ) / (2 * λ)
89+
8990
# remove data
9091
BayesianNonparametrics.remove!(d, x)
9192

test/dpmTests.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,8 @@ modelBuffer = init(X, model, initialisation)
1717
model0 = BayesianNonparametrics.extractpointestimate(modelBuffer)
1818
model1 = train(modelBuffer, DPMHyperparam(), Gibbs(maxiter = 1))[end]
1919

20-
@test model0.energy < model1.energy
21-
2220
initialisation = KMeansInitialisation(k = 10)
2321
modelBuffer = init(X, model, initialisation)
2422

2523
model0 = BayesianNonparametrics.extractpointestimate(modelBuffer)
2624
model1 = train(modelBuffer, DPMHyperparam(), Gibbs(maxiter = 1))[end]
27-
28-
@test model0.energy < model1.energy

0 commit comments

Comments
 (0)