File tree Expand file tree Collapse file tree 4 files changed +29
-3
lines changed Expand file tree Collapse file tree 4 files changed +29
-3
lines changed Original file line number Diff line number Diff line change 1
1
name = " MLJModelInterface"
2
2
uuid = " e80e1ace-859a-464e-9ed9-23947d8ae3ea"
3
3
authors = [" Thibaut Lienart and Anthony Blaom" ]
4
- version = " 1.10 .0"
4
+ version = " 1.11 .0"
5
5
6
6
[deps ]
7
7
Random = " 9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -19,7 +19,7 @@ OrderedCollections = "1"
19
19
Random = " <0.0.1, 1"
20
20
ScientificTypes = " 3"
21
21
ScientificTypesBase = " 3"
22
- StatisticalTraits = " 3.3 "
22
+ StatisticalTraits = " 3.4 "
23
23
Tables = " 1"
24
24
Test = " <0.0.1, 1"
25
25
julia = " 1.6"
Original file line number Diff line number Diff line change @@ -8,6 +8,7 @@ const MODEL_TRAITS = [
8
8
:predict_scitype ,
9
9
:transform_scitype ,
10
10
:inverse_transform_scitype ,
11
+ :target_in_fit ,
11
12
:is_pure_julia ,
12
13
:package_name ,
13
14
:package_license ,
Original file line number Diff line number Diff line change 32
32
StatTraits. is_supervised (:: Type{<:Supervised} ) = true
33
33
StatTraits. is_supervised (:: Type{<:SupervisedAnnotator} ) = true
34
34
35
+ StatTraits. target_in_fit (:: Type{<:Supervised} ) = true
36
+ StatTraits. target_in_fit (:: Type{<:Unsupervised} ) = false
37
+
35
38
StatTraits. prediction_type (:: Type{<:Deterministic} ) = :deterministic
36
39
StatTraits. prediction_type (:: Type{<:Probabilistic} ) = :probabilistic
37
40
StatTraits. prediction_type (:: Type{<:Interval} ) = :interval
@@ -73,7 +76,15 @@ function supervised_fit_data_scitype(M)
73
76
return ret
74
77
end
75
78
76
- StatTraits. fit_data_scitype (M:: Type{<:Unsupervised} ) = Tuple{input_scitype (M)}
79
+ # helper to determine the scitype of unsupervised models
80
+ function unsupervised_fit_data_scitype (M)
81
+ I = input_scitype (M)
82
+ T = target_scitype (M)
83
+ target_in_fit (M) && return Tuple{I, T}
84
+ return Tuple{I}
85
+ end
86
+
87
+ StatTraits. fit_data_scitype (M:: Type{<:Unsupervised} ) = unsupervised_fit_data_scitype (M)
77
88
StatTraits. fit_data_scitype (:: Type{<:Static} ) = Tuple{}
78
89
StatTraits. fit_data_scitype (M:: Type{<:Supervised} ) = supervised_fit_data_scitype (M)
79
90
Original file line number Diff line number Diff line change 23
23
@mlj_model mutable struct UA <: UnsupervisedAnnotator
24
24
end
25
25
26
+ @mlj_model mutable struct SupervisedTransformer <: Unsupervised
27
+ end
28
+
29
+
26
30
foo (:: P1 ) = 0
27
31
bar (:: P1 ) = nothing
28
32
@@ -34,6 +38,10 @@ M.package_name(::Type{<:U1}) = "Bach"
34
38
M. package_url (:: Type{<:U1} ) = " www.did_he_write_565.com"
35
39
M. human_name (:: Type{<:U1} ) = " funky model"
36
40
41
+ M. target_in_fit (:: Type{<:SupervisedTransformer} ) = true
42
+ M. target_scitype (:: Type{<:SupervisedTransformer} ) = Continuous
43
+ M. input_scitype (:: Type{<:SupervisedTransformer} ) = Finite
44
+
37
45
@testset " traits" begin
38
46
ms = S1 ()
39
47
mu = U1 (a= 42 , b= sin)
@@ -42,6 +50,7 @@ M.human_name(::Type{<:U1}) = "funky model"
42
50
mi = I1 ()
43
51
sa = SA ()
44
52
ua = UA ()
53
+ supervised_transformer = SupervisedTransformer ()
45
54
46
55
@test input_scitype (ms) == Unknown
47
56
@test output_scitype (ms) == Unknown
@@ -115,6 +124,11 @@ M.human_name(::Type{<:U1}) = "funky model"
115
124
setfull ()
116
125
117
126
@test Set (implemented_methods (mp)) == Set ([:clean! ,:bar ,:foo ])
127
+
128
+ @test fit_data_scitype (mu) == Tuple{Unknown};;;
129
+ @test fit_data_scitype (mu) == Tuple{Unknown}
130
+ @test fit_data_scitype (supervised_transformer) == Tuple{Finite,Continuous}
131
+
118
132
end
119
133
120
134
@testset " `_density` - helper for predict_scitype fallback" begin
You can’t perform that action at this time.
0 commit comments