Skip to content

Add ZeroInflatedPoisson distribution #1393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ version = "0.25.15"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LambertW = "984bce1d-4616-540c-a9ee-88d1112d94c9"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it would make sense to define this distribution in StatsFuns rather than using a separate package?

@jlapeyre Would you be OK with that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC there was a discussion about moving it to SpecialFunctions and there even exists a PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes that's JuliaMath/SpecialFunctions.jl#84. Though it's quite outdated now and there have been new commits in LambertW since then.

@emfeltham Do you feel like reviving this PR (or opening a new one)? There seems to be lots of interest in it.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nalimilan (sorry, I just saw a new email ping) Yes, I am definitely OK with moving LambertW into another package. I think the appropriate package is indeed SpecialFunctions. I don't have a pressing interest in doing it myself at the moment. I'm not sure a new attempt at a PR wouldn't fizzle out as well ;)

LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Expand Down
1 change: 1 addition & 0 deletions docs/src/fit.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ The `fit_mle` method has been implemented for the following distributions:
- [`InverseGaussian`](@ref)
- [`Uniform`](@ref)
- [`Weibull`](@ref)
- [`ZeroInflatedPoisson`](@ref)

**Multivariate:**

Expand Down
1 change: 1 addition & 0 deletions docs/src/univariate.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ NegativeBinomial
Poisson
PoissonBinomial
Skellam
ZeroInflatedPoisson
```

### Vectorized evaluation
Expand Down
6 changes: 4 additions & 2 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import StatsBase: kurtosis, skewness, entropy, mode, modes,

import PDMats: dim, PDMat, invquad

using SpecialFunctions
using SpecialFunctions, LambertW

import ChainRulesCore

Expand Down Expand Up @@ -161,6 +161,7 @@ export
WalleniusNoncentralHypergeometric,
Weibull,
Wishart,
ZeroInflatedPoisson,
ZeroMeanIsoNormal,
ZeroMeanIsoNormalCanon,
ZeroMeanDiagNormal,
Expand Down Expand Up @@ -239,6 +240,7 @@ export
quantile, # inverse of cdf (defined for p in (0,1))
qqbuild, # build a paired quantiles data structure for qqplots
rate, # get the rate parameter
excessprob, # get the exess probability of zeros parameter (ZeroInflatedPoison)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if we should introduce and in particular export a new function here just for ZeroInflatedPoisson.

sampler, # create a Sampler object for efficient samples
scale, # get the scale parameter
scale!, # provide storage for the scale parameter (used in multivariate distribution mvlognormal)
Expand Down Expand Up @@ -334,7 +336,7 @@ Supported distributions:
QQPair, Rayleigh, Skellam, Soliton, StudentizedRange, SymTriangularDist, TDist, TriangularDist,
Triweight, Truncated, TruncatedNormal, Uniform, UnivariateGMM,
VonMises, VonMisesFisher, WalleniusNoncentralHypergeometric, Weibull,
Wishart, ZeroMeanIsoNormal, ZeroMeanIsoNormalCanon,
Wishart, ZeroInflatedPoisson, ZeroMeanIsoNormal, ZeroMeanIsoNormalCanon,
ZeroMeanDiagNormal, ZeroMeanDiagNormalCanon, ZeroMeanFullNormal,
ZeroMeanFullNormalCanon

Expand Down
165 changes: 165 additions & 0 deletions src/univariate/discrete/zeroinflatedpoisson.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""
ZeroInflatedPoisson(λ, p)
A *Zero-Inflated Poisson distribution* is a mixture distribution in which data arise from two processes. The first process is is a Poisson distribution, with mean λ, that descibes the number of independent events occurring within a unit time interval:
```math
P(X = k) = (1 - p) \\frac{\\lambda^k}{k!} e^{-\\lambda}, \\quad \\text{ for } k = 0,1,2,\\ldots.
```
Zeros may arise from this process, an additional Bernoulli process, where the probability of observing an excess zero is given as p:
```math
P(X = 0) = p + (1 - p) e^{-\\lambda}
```
As p approaches 0, the distribution tends toward Poisson(λ).
```julia
ZeroInflatedPoisson() # Zero-Inflated Poisson distribution with rate parameter 1, and probability of observing a zero 0.5
ZeroInflatedPoisson(λ) # ZeroInflatedPoisson distribution with rate parameter λ, and probability of observing a zero 0.5
params(d) # Get the parameters, i.e. (λ, p)
mean(d) # Get the mean of the mixture distribution
var(d) # Get the variance of the mixture distribution
```
External links:
* [Zero-inflated Poisson Regression on UCLA IDRE Statistical Consulting](https://stats.idre.ucla.edu/stata/dae/zero-inflated-poisson-regression/)
* [Zero-inflated model on Wikipedia](https://en.wikipedia.org/wiki/Zero-inflated_model)
* McElreath, R. (2020). Statistical Rethinking: A Bayesian Course with Examples in R and Stan (2nd ed.). Chapman and Hall/CRC. https://doi.org/10.1201/9780429029608

"""
struct ZeroInflatedPoisson{T<:Real} <: DiscreteUnivariateDistribution
λ::T
p::T

function ZeroInflatedPoisson{T}(λ::T, p::T) where {T <: Real}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you adapt your indentation to 4 spaces?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do.

return new{T}(λ, p)
end
end

function ZeroInflatedPoisson(λ::T, p::T; check_args = true) where {T <: Real}
if check_args
@check_args(Poisson, λ >= zero(λ))
@check_args(ZeroInflatedPoisson, zero(p) <= p <= one(p))
end
return ZeroInflatedPoisson{T}(λ, p)
end

ZeroInflatedPoisson(λ::Real, p::Real) = ZeroInflatedPoisson(promote(λ, p)...)
ZeroInflatedPoisson(λ::Integer, p::Integer) = ZeroInflatedPoisson(float(λ), float(p))
ZeroInflatedPoisson(λ::Real) = ZeroInflatedPoisson(λ, 0.0)
ZeroInflatedPoisson() = ZeroInflatedPoisson(1.0, 0.0, check_args = false)

@distr_support ZeroInflatedPoisson 0 (d.λ == zero(typeof(d.λ)) ? 0 : Inf)

### Statistics

mean(d::ZeroInflatedPoisson) = (1 - d.p) * d.λ

var(d::ZeroInflatedPoisson) = d.λ * (1 - d.p) * (1 + d.p * d.λ)

#### Conversions

function convert(::Type{ZeroInflatedPoisson{T}}, λ::Real, p::Real) where {T<:Real}
return ZeroInflatedPoisson(T(λ), T(p))
end

function convert(::Type{ZeroInflatedPoisson{T}}, d::ZeroInflatedPoisson{S}) where {T <: Real, S <: Real}
return ZeroInflatedPoisson(T(d.λ), T(d.p), check_args = false)
end

#### Parameters

params(d::ZeroInflatedPoisson) = (d.λ, d.p,)
partype(::ZeroInflatedPoisson{T}) where {T} = T

rate(d::ZeroInflatedPoisson) = d.λ
excessprob(d::ZeroInflatedPoisson) = d.p

#### Evaluation

function logpdf(d::ZeroInflatedPoisson, y::Real)
lp = if iszero(y)
logaddexp(log(d.p), log1p(-d.p) - d.λ)
else
log1p(-d.p) + logpdf(Poisson(d.λ), y)
end
return lp
end

function cdf(d::ZeroInflatedPoisson, x::Real)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is type unstable (eg. if the parameters and x are of type Float32). The best approach is to perform all calculations and just return zero(result) or oftype(result, NaN) if necessary for some values of x in the end.

pd = Poisson(d.λ)

deflat_limit = -1.0 / expm1(d.λ)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should avoid harcoded Float64 literals since they might cause unwanted promotions.


if x < 0
out = 0.0
elseif d.p < deflat_limit
out = NaN
else
out = d.p + (1 - d.p) * cdf(pd, x)
end
return out
end

# quantile
function quantile(d::ZeroInflatedPoisson, q::Real)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function has the same problems as cdf.


deflat_limit = -1.0 / expm1(d.λ)

if (q <= d.p)
out = 0
elseif (d.p < deflat_limit)
out = convert(Int64, NaN)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will throw an InexactError. In general, one should also avoid hardcoding Int64 and use Int (defaults to Int64 on 64bit) if possible since Int64 can lead to a mixup of Int32 and Int64 on 32bit machines.

elseif (d.p < q) & (deflat_limit <= d.p) & (q < 1.0)
qp = (q - d.p) / (1.0 - d.p)
pd = Poisson(d.λ)
out = quantile(pd, qp) # handles d.p == 1 as InexactError(Inf)
end
return out
end

#### Fitting

struct ZeroInflatedPoissonStats <: SufficientStats
sx::Float64 # (weighted) sum of x
p0::Float64 # observed proportion of zeros
tw::Float64 # total sample weight
Comment on lines +119 to +121
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The types seem a bit restrictive.

end

suffstats(::Type{<:ZeroInflatedPoisson}, x::AbstractArray{T}) where {T<:Integer} = ZeroInflatedPoissonStats(
sum(x),
mean(iszero, x),
length(x)
)

# weighted
function suffstats(::Type{<:ZeroInflatedPoisson}, x::AbstractArray{T}, w::AbstractArray{Float64}) where T<:Integer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Float64 type parameter seems a bit restrictive.

n = length(x)
n == length(w) || throw(DimensionMismatch("Inconsistent array lengths."))
sx = 0.
tw = 0.
p0 = 0.
for i = 1 : n
@inbounds wi = w[i]
@inbounds sx += x[i] * wi
tw += wi
@inbounds p0i = (x[i] == 0) * wi
p0 += p0i
end
return ZeroInflatedPoissonStats(sx, p0, tw)
end

function fit_mle(::Type{<:ZeroInflatedPoisson}, ss::ZeroInflatedPoissonStats)
m = ss.sx / ss.tw
s = m / (1 - ss.p0)

λhat = lambertw(-s * exp(-s)) + s
phat = 1 - (m / λhat)

return ZeroInflatedPoisson(λhat, phat)
end

function fit_mle(::Type{<:ZeroInflatedPoisson}, x::AbstractArray{T}) where T<:Real
pstat = suffstats(ZeroInflatedPoisson, x)
return fit_mle(ZeroInflatedPoisson, pstat)
end

function fit_mle(::Type{<:ZeroInflatedPoisson}, x::AbstractArray{T}, w::AbstractArray{Float64}) where T<:Real
pstat = suffstats(ZeroInflatedPoisson, x, w)
return fit_mle(ZeroInflatedPoisson, pstat)
end
3 changes: 2 additions & 1 deletion src/univariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,8 @@ const discrete_distributions = [
"poisson",
"skellam",
"soliton",
"poissonbinomial"
"poissonbinomial",
"zeroinflatedpoisson"
]

const continuous_distributions = [
Expand Down
32 changes: 32 additions & 0 deletions test/ref/discrete/zeroinflatedpoisson.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

ZeroInflatedPoisson <- R6Class("ZeroInflatedPoisson",
inherit = DiscreteDistribution,
public = list(
names = c("lambda", "p"),
lambda = NA,
p = NA,
initialize = function(lambda = 1, p = 0) {
self$lambda <- lambda
self$p <- p
},
supp = function() { c(0, Inf) },
properties = function() {
lam <- self$lambda
p <- self$p
list(rate = lam,
excessprob = p,
mean = (1 - p) * lam,
var = lam * (1 - p) * (1 + p * lam)
)
},
pdf = function(x, log=FALSE) {
VGAM::dzipois(x, self$lambda, pstr0 = self$p, log = log)
},
cdf = function(x) {
VGAM::pzipois(x, self$lambda, pstr0 = self$p)
},
quan = function(v) {
VGAM::qzipois(v, self$lambda, pstr0 = self$p)
}
)
)
11 changes: 11 additions & 0 deletions test/ref/discrete_test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,14 @@ WalleniusNoncentralHypergeometric(8, 6, 10, 0.1)
WalleniusNoncentralHypergeometric(40, 30, 50, 1)
WalleniusNoncentralHypergeometric(40, 30, 50, 0.5)
WalleniusNoncentralHypergeometric(40, 30, 50, 2)

ZeroInflatedPoisson()
ZeroInflatedPoisson(1.0)
ZeroInflatedPoisson(0.5, 0.0)
ZeroInflatedPoisson(0.5, 1.0)
ZeroInflatedPoisson(2.0, 0.0)
ZeroInflatedPoisson(2.0, 1.0)
ZeroInflatedPoisson(10.0, 0.0)
ZeroInflatedPoisson(10.0, 1.0)
ZeroInflatedPoisson(80.0, 0.0)
ZeroInflatedPoisson(80.0, 1.0)
1 change: 1 addition & 0 deletions test/ref/rdistributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ source("discrete/negativebinomial.R")
source("discrete/noncentralhypergeometric.R")
source("discrete/poisson.R")
source("discrete/skellam.R")
source("discrete/zeroinflatedpoisson.R")

#################################################
#
Expand Down
2 changes: 1 addition & 1 deletion test/ref/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ in addition to the R language itself:
| stringr | For string parsing |
| R6 | OOP for implementing distributions |
| extraDistr | A number of distributions |
| VGAM | For ``Frechet`` and ``Levy`` |
| VGAM | For ``Frechet`` and ``Levy`` and ``ZeroInflatedPoisson``|
| distr | For ``Arcsine`` |
| chi | For ``Chi`` |
| circular | For ``VonMises`` |
Expand Down