Skip to content

Commit 8107f2c

Browse files
authored
Merge pull request #205 from JuliaAI/dev
For a 1.11.0 release
2 parents 08414af + e3d2571 commit 8107f2c

File tree

4 files changed

+29
-3
lines changed

4 files changed

+29
-3
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "1.10.0"
4+
version = "1.11.0"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -19,7 +19,7 @@ OrderedCollections = "1"
1919
Random = "<0.0.1, 1"
2020
ScientificTypes = "3"
2121
ScientificTypesBase = "3"
22-
StatisticalTraits = "3.3"
22+
StatisticalTraits = "3.4"
2323
Tables = "1"
2424
Test = "<0.0.1, 1"
2525
julia = "1.6"

src/MLJModelInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ const MODEL_TRAITS = [
88
:predict_scitype,
99
:transform_scitype,
1010
:inverse_transform_scitype,
11+
:target_in_fit,
1112
:is_pure_julia,
1213
:package_name,
1314
:package_license,

src/model_traits.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ end
3232
StatTraits.is_supervised(::Type{<:Supervised}) = true
3333
StatTraits.is_supervised(::Type{<:SupervisedAnnotator}) = true
3434

35+
StatTraits.target_in_fit(::Type{<:Supervised}) = true
36+
StatTraits.target_in_fit(::Type{<:Unsupervised}) = false
37+
3538
StatTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
3639
StatTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
3740
StatTraits.prediction_type(::Type{<:Interval}) = :interval
@@ -73,7 +76,15 @@ function supervised_fit_data_scitype(M)
7376
return ret
7477
end
7578

76-
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) = Tuple{input_scitype(M)}
79+
# helper to determine the scitype of unsupervised models
80+
function unsupervised_fit_data_scitype(M)
81+
I = input_scitype(M)
82+
T = target_scitype(M)
83+
target_in_fit(M) && return Tuple{I, T}
84+
return Tuple{I}
85+
end
86+
87+
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) = unsupervised_fit_data_scitype(M)
7788
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
7889
StatTraits.fit_data_scitype(M::Type{<:Supervised}) = supervised_fit_data_scitype(M)
7990

test/model_traits.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ end
2323
@mlj_model mutable struct UA <: UnsupervisedAnnotator
2424
end
2525

26+
@mlj_model mutable struct SupervisedTransformer <: Unsupervised
27+
end
28+
29+
2630
foo(::P1) = 0
2731
bar(::P1) = nothing
2832

@@ -34,6 +38,10 @@ M.package_name(::Type{<:U1}) = "Bach"
3438
M.package_url(::Type{<:U1}) = "www.did_he_write_565.com"
3539
M.human_name(::Type{<:U1}) = "funky model"
3640

41+
M.target_in_fit(::Type{<:SupervisedTransformer}) = true
42+
M.target_scitype(::Type{<:SupervisedTransformer}) = Continuous
43+
M.input_scitype(::Type{<:SupervisedTransformer}) = Finite
44+
3745
@testset "traits" begin
3846
ms = S1()
3947
mu = U1(a=42, b=sin)
@@ -42,6 +50,7 @@ M.human_name(::Type{<:U1}) = "funky model"
4250
mi = I1()
4351
sa = SA()
4452
ua = UA()
53+
supervised_transformer = SupervisedTransformer()
4554

4655
@test input_scitype(ms) == Unknown
4756
@test output_scitype(ms) == Unknown
@@ -115,6 +124,11 @@ M.human_name(::Type{<:U1}) = "funky model"
115124
setfull()
116125

117126
@test Set(implemented_methods(mp)) == Set([:clean!,:bar,:foo])
127+
128+
@test fit_data_scitype(mu) == Tuple{Unknown};;;
129+
@test fit_data_scitype(mu) == Tuple{Unknown}
130+
@test fit_data_scitype(supervised_transformer) == Tuple{Finite,Continuous}
131+
118132
end
119133

120134
@testset "`_density` - helper for predict_scitype fallback" begin

0 commit comments

Comments
 (0)