@@ -25,3 +25,44 @@ mutable struct APIx1 <: Static end
2525 # update fallback = fit
2626 @test update (m0, 1 , 5 , nothing , randn (2 ), 5 ) == (5 , nothing , nothing )
2727end
28+
29+ struct DummyUnivariateFinite end
30+
31+ mutable struct UnivariateFiniteFitter <: Probabilistic end
32+
33+ @testset " models fitting a distribution to data" begin
34+
35+ function MLJModelInterface. fit (model:: UnivariateFiniteFitter ,
36+ verbosity:: Int , X, y)
37+
38+ fitresult = DummyUnivariateFinite ()
39+ report = nothing
40+ cache = nothing
41+
42+ verbosity > 0 && @info " Fitted a $fitresult "
43+
44+ return fitresult, cache, report
45+ end
46+
47+ MLJModelInterface. predict (model:: UnivariateFiniteFitter ,
48+ fitresult,
49+ X) = fill (fitresult, length (X))
50+
51+ MLJModelInterface. input_scitype (:: Type{<:UnivariateFiniteFitter} ) =
52+ AbstractVector{Nothing}
53+ MLJModelInterface. target_scitype (:: Type{<:UnivariateFiniteFitter} ) =
54+ AbstractVector{<: Finite }
55+
56+ y = categorical (collect (" aabbccaa" ))
57+ X = fill (nothing , length (y))
58+ model = UnivariateFiniteFitter ()
59+ fitresult, cache, report = MLJModelInterface. fit (model, 1 , X, y)
60+
61+ @test cache == nothing
62+ @test report == nothing
63+
64+ ytest = y[1 : 3 ]
65+ yhat = predict (model, fitresult, fill (nothing , 3 ))
66+ @test yhat == fill (DummyUnivariateFinite (), 3 )
67+
68+ end
0 commit comments