Skip to content

Commit fa74d14

Browse files
authored
Allow for sparse arrays and views (#446)
1 parent be96833 commit fa74d14

File tree

5 files changed

+91
-15
lines changed

5 files changed

+91
-15
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GLM"
22
uuid = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
3-
version = "1.6.0"
3+
version = "1.6.1"
44

55
[deps]
66
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/glmfit.jl

+12-8
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,15 @@ function GlmResp(y::V, d::D, l::L, η::V, μ::V, off::V, wts::V) where {V<:FPVec
4949
return GlmResp{V,D,L}(y, d, similar(y), η, μ, off, wts, similar(y), similar(y))
5050
end
5151

52-
function GlmResp(y::V, d::D, l::L, off::V, wts::V) where {V<:FPVector,D,L}
53-
η = similar(y)
54-
μ = similar(y)
55-
r = GlmResp(y, d, l, η, μ, off, wts)
56-
initialeta!(r.eta, d, l, y, wts, off)
52+
function GlmResp(y::FPVector, d::Distribution, l::Link, off::FPVector, wts::FPVector)
53+
# Instead of convert(Vector{Float64}, y) to be more ForwardDiff friendly
54+
_y = convert(Vector{float(eltype(y))}, y)
55+
_off = convert(Vector{float(eltype(off))}, off)
56+
_wts = convert(Vector{float(eltype(wts))}, wts)
57+
η = similar(_y)
58+
μ = similar(_y)
59+
r = GlmResp(_y, d, l, η, μ, _off, _wts)
60+
initialeta!(r.eta, d, l, _y, _wts, _off)
5761
updateμ!(r, r.eta)
5862
return r
5963
end
@@ -465,14 +469,14 @@ Fit a generalized linear model to data.
465469
$FIT_GLM_DOC
466470
"""
467471
function fit(::Type{M},
468-
X::Union{Matrix{T},SparseMatrixCSC{T}},
472+
X::AbstractMatrix{<:FP},
469473
y::AbstractVector{<:Real},
470474
d::UnivariateDistribution,
471475
l::Link = canonicallink(d);
472476
dofit::Bool = true,
473477
wts::AbstractVector{<:Real} = similar(y, 0),
474478
offset::AbstractVector{<:Real} = similar(y, 0),
475-
fitargs...) where {M<:AbstractGLM,T<:FP}
479+
fitargs...) where {M<:AbstractGLM}
476480

477481
# Check that X and y have the same number of observations
478482
if size(X, 1) != size(y, 1)
@@ -485,7 +489,7 @@ function fit(::Type{M},
485489
end
486490

487491
fit(::Type{M},
488-
X::Union{Matrix,SparseMatrixCSC},
492+
X::AbstractMatrix,
489493
y::AbstractVector,
490494
d::UnivariateDistribution,
491495
l::Link=canonicallink(d); kwargs...) where {M<:AbstractGLM} =

src/linpred.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ function SparsePredChol(X::SparseMatrixCSC{T}) where T
185185
similar(X))
186186
end
187187

188-
cholpred(X::SparseMatrixCSC) = SparsePredChol(X)
188+
cholpred(X::SparseMatrixCSC, pivot::Bool=false) = SparsePredChol(X)
189189

190190
function delbeta!(p::SparsePredChol{T}, r::Vector{T}, wt::Vector{T}) where T
191191
scr = mul!(p.scratch, Diagonal(wt), p.X)
@@ -194,6 +194,13 @@ function delbeta!(p::SparsePredChol{T}, r::Vector{T}, wt::Vector{T}) where T
194194
p.delbeta = c \ mul!(p.delbeta, adjoint(scr), r)
195195
end
196196

197+
function delbeta!(p::SparsePredChol{T}, r::Vector{T}) where T
198+
scr = p.scratch = p.X
199+
XtWX = p.Xt*scr
200+
c = p.chol = cholesky(Symmetric{eltype(XtWX),typeof(XtWX)}(XtWX, 'L'))
201+
p.delbeta = c \ mul!(p.delbeta, adjoint(scr), r)
202+
end
203+
197204
LinearAlgebra.cholesky(p::SparsePredChol{T}) where {T} = copy(p.chol)
198205
LinearAlgebra.cholesky!(p::SparsePredChol{T}) where {T} = p.chol
199206

src/lm.jl

+10-5
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,16 @@ mutable struct LmResp{V<:FPVector} <: ModResp # response in a linear model
2828
end
2929
end
3030

31-
LmResp(y::FPVector, wts::FPVector=similar(y, 0)) =
32-
LmResp{typeof(y)}(fill!(similar(y), 0), similar(y, 0), wts, y)
33-
34-
LmResp(y::AbstractVector{<:Real}, wts::AbstractVector{<:Real}=similar(y, 0)) =
35-
LmResp(float(y), float(wts))
31+
function LmResp(y::AbstractVector{<:Real}, wts::Union{Nothing,AbstractVector{<:Real}}=nothing)
32+
# Instead of convert(Vector{Float64}, y) to be more ForwardDiff friendly
33+
_y = convert(Vector{float(eltype(y))}, y)
34+
_wts = if wts === nothing
35+
similar(_y, 0)
36+
else
37+
convert(Vector{float(eltype(wts))}, wts)
38+
end
39+
return LmResp{typeof(_y)}(zero(_y), zero(_y), _wts, _y)
40+
end
3641

3742
function updateμ!(r::LmResp{V}, linPr::V) where V<:FPVector
3843
n = length(linPr)

test/runtests.jl

+60
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,23 @@ end
502502
@test isapprox(vcov(gmsparse), vcov(gmdense))
503503
end
504504

505+
@testset "Sparse LM" begin
506+
rng = StableRNG(1)
507+
X = sprand(rng, 1000, 10, 0.01)
508+
β = randn(rng, 10)
509+
y = Bool[rand(rng) < logistic(x) for x in X * β]
510+
gmsparsev = [fit(LinearModel, X, y),
511+
fit(LinearModel, X, sparse(y)),
512+
fit(LinearModel, Matrix(X), sparse(y))]
513+
gmdense = fit(LinearModel, Matrix(X), y)
514+
515+
for gmsparse in gmsparsev
516+
@test isapprox(deviance(gmsparse), deviance(gmdense))
517+
@test isapprox(coef(gmsparse), coef(gmdense))
518+
@test isapprox(vcov(gmsparse), vcov(gmdense))
519+
end
520+
end
521+
505522
@testset "Predict" begin
506523
rng = StableRNG(123)
507524
X = rand(rng, 10, 2)
@@ -969,3 +986,46 @@ end
969986
secondcolinterceptmod = glm([randn(rng, 5) ones(5)], ones(5), Binomial(), LogitLink())
970987
@test hasintercept(secondcolinterceptmod)
971988
end
989+
990+
@testset "Issue #444. Views" begin
991+
X = randn(10, 2)
992+
y = X*ones(2) + randn(10)
993+
@test coef(glm(X, y, Normal(), IdentityLink())) ==
994+
coef(glm(view(X, 1:10, :), view(y, 1:10), Normal(), IdentityLink()))
995+
996+
x, y, w = rand(100, 2), rand(100), rand(100)
997+
lm1 = lm(x, y)
998+
lm2 = lm(x, view(y, :))
999+
lm3 = lm(view(x, :, :), y)
1000+
lm4 = lm(view(x, :, :), view(y, :))
1001+
@test coef(lm1) == coef(lm2) == coef(lm3) == coef(lm4)
1002+
1003+
lm5 = lm(x, y, wts=w)
1004+
lm6 = lm(x, view(y, :), wts=w)
1005+
lm7 = lm(view(x, :, :), y, wts=w)
1006+
lm8 = lm(view(x, :, :), view(y, :), wts=w)
1007+
lm9 = lm(x, y, wts=view(w, :))
1008+
lm10 = lm(x, view(y, :), wts=view(w, :))
1009+
lm11 = lm(view(x, :, :), y, wts=view(w, :))
1010+
lm12 = lm(view(x, :, :), view(y, :), wts=view(w, :))
1011+
@test coef(lm5) == coef(lm6) == coef(lm7) == coef(lm8) == coef(lm9) == coef(lm10) ==
1012+
coef(lm11) == coef(lm12)
1013+
1014+
x, y, w = rand(100, 2), rand(Bool, 100), rand(100)
1015+
glm1 = glm(x, y, Binomial())
1016+
glm2 = glm(x, view(y, :), Binomial())
1017+
glm3 = glm(view(x, :, :), y, Binomial())
1018+
glm4 = glm(view(x, :, :), view(y, :), Binomial())
1019+
@test coef(glm1) == coef(glm2) == coef(glm3) == coef(glm4)
1020+
1021+
glm5 = glm(x, y, Binomial(), wts=w)
1022+
glm6 = glm(x, view(y, :), Binomial(), wts=w)
1023+
glm7 = glm(view(x, :, :), y, Binomial(), wts=w)
1024+
glm8 = glm(view(x, :, :), view(y, :), Binomial(), wts=w)
1025+
glm9 = glm(x, y, Binomial(), wts=view(w, :))
1026+
glm10 = glm(x, view(y, :), Binomial(), wts=view(w, :))
1027+
glm11 = glm(view(x, :, :), y, Binomial(), wts=view(w, :))
1028+
glm12 = glm(view(x, :, :), view(y, :), Binomial(), wts=view(w, :))
1029+
@test coef(glm5) == coef(glm6) == coef(glm7) == coef(glm8) == coef(glm9) == coef(glm10) ==
1030+
coef(glm11) == coef(glm12)
1031+
end

0 commit comments

Comments
 (0)