-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathregressionmodel.jl
36 lines (26 loc) · 994 Bytes
/
regressionmodel.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
module TestRegressionModel

# Exercises the `StatsAPI.crossmodelmatrix` method inherited by `RegressionModel`
# subtypes that define only `modelmatrix`: it is expected to return the Gram matrix
# `X'X` of the model matrix as a `Symmetric`, and `X' * Diagonal(wts) * X` when
# `weighted=true`.

using Test, LinearAlgebra, StatsAPI
using StatsAPI: RegressionModel, crossmodelmatrix

# Minimal model whose `modelmatrix` accepts no keyword arguments, so only the
# unweighted path of `crossmodelmatrix` is exercised with it.
struct MyRegressionModel <: RegressionModel end

# Minimal model carrying per-observation weights. The field type is a type parameter
# (rather than the abstract `AbstractVector`) so the field is concretely typed.
struct MyWeightedRegressionModel{W<:AbstractVector} <: RegressionModel
    wts::W
end

StatsAPI.modelmatrix(::MyRegressionModel) = [1 2; 3 4]

# With `weighted=true`, row i of the model matrix is scaled by sqrt(wts[i]); the
# product of that matrix with its transpose is then X' * Diagonal(wts) * X.
function StatsAPI.modelmatrix(r::MyWeightedRegressionModel; weighted::Bool=false)
    X = [1 2; 3 4]
    return weighted ? sqrt.(r.wts) .* X : X
end

@testset "TestRegressionModel" begin
    w = [0.3, 0.2]
    X = [1 2; 3 4]
    m = MyRegressionModel()
    r = MyWeightedRegressionModel(w)

    # Unweighted: X'X for X = [1 2; 3 4]
    @test crossmodelmatrix(m) == [10 14; 14 20]
    @test crossmodelmatrix(m; weighted=false) == [10 14; 14 20]
    @test crossmodelmatrix(m) isa Symmetric

    @test crossmodelmatrix(r) == [10 14; 14 20]
    @test crossmodelmatrix(r; weighted=false) == [10 14; 14 20]
    @test crossmodelmatrix(r) isa Symmetric

    # Weighted: 0.3 * [1 2; 2 4] + 0.2 * [9 12; 12 16]
    @test crossmodelmatrix(r; weighted=true) ≈ [2.1 3.0; 3.0 4.4]
    @test crossmodelmatrix(r; weighted=true) ≈ X' * Diagonal(w) * X
    @test crossmodelmatrix(r; weighted=true) isa Symmetric
end

end # module TestRegressionModel