@@ -103,24 +103,84 @@ def test_weights_deserializable(name):
103
103
assert pickle .loads (pickle .dumps (weights )) is weights
104
104
105
105
106
+ def get_models_from_module (module ):
107
+ return [
108
+ v .__name__
109
+ for k , v in module .__dict__ .items ()
110
+ if callable (v ) and k [0 ].islower () and k [0 ] != "_" and k not in models ._api .__all__
111
+ ]
112
+
113
+
106
114
@pytest .mark .parametrize (
107
115
"module" , [models , models .detection , models .quantization , models .segmentation , models .video , models .optical_flow ]
108
116
)
109
117
def test_list_models (module ):
110
- def get_models_from_module (module ):
111
- return [
112
- v .__name__
113
- for k , v in module .__dict__ .items ()
114
- if callable (v ) and k [0 ].islower () and k [0 ] != "_" and k not in models ._api .__all__
115
- ]
116
-
117
118
a = set (get_models_from_module (module ))
118
119
b = set (x .replace ("quantized_" , "" ) for x in models .list_models (module ))
119
120
120
121
assert len (b ) > 0
121
122
assert a == b
122
123
123
124
125
+ @pytest .mark .parametrize (
126
+ "include_filters" ,
127
+ [
128
+ None ,
129
+ [],
130
+ (),
131
+ "" ,
132
+ "*resnet*" ,
133
+ ["*alexnet*" ],
134
+ "*not-existing-model-for-test?" ,
135
+ ["*resnet*" , "*alexnet*" ],
136
+ ["*resnet*" , "*alexnet*" , "*not-existing-model-for-test?" ],
137
+ ("*resnet*" , "*alexnet*" ),
138
+ set (["*resnet*" , "*alexnet*" ]),
139
+ ],
140
+ )
141
+ @pytest .mark .parametrize (
142
+ "exclude_filters" ,
143
+ [
144
+ None ,
145
+ [],
146
+ (),
147
+ "" ,
148
+ "*resnet*" ,
149
+ ["*alexnet*" ],
150
+ ["*not-existing-model-for-test?" ],
151
+ ["resnet34" , "*not-existing-model-for-test?" ],
152
+ ["resnet34" , "*resnet1*" ],
153
+ ("resnet34" , "*resnet1*" ),
154
+ set (["resnet34" , "*resnet1*" ]),
155
+ ],
156
+ )
157
+ def test_list_models_filters (include_filters , exclude_filters ):
158
+ actual = set (models .list_models (models , include = include_filters , exclude = exclude_filters ))
159
+ classification_models = set (get_models_from_module (models ))
160
+
161
+ if isinstance (include_filters , str ):
162
+ include_filters = [include_filters ]
163
+ if isinstance (exclude_filters , str ):
164
+ exclude_filters = [exclude_filters ]
165
+
166
+ if include_filters :
167
+ expected = set ()
168
+ for include_f in include_filters :
169
+ include_f = include_f .strip ("*?" )
170
+ expected = expected | set (x for x in classification_models if include_f in x )
171
+ else :
172
+ expected = classification_models
173
+
174
+ if exclude_filters :
175
+ for exclude_f in exclude_filters :
176
+ exclude_f = exclude_f .strip ("*?" )
177
+ if exclude_f != "" :
178
+ a_exclude = set (x for x in classification_models if exclude_f in x )
179
+ expected = expected - a_exclude
180
+
181
+ assert expected == actual
182
+
183
+
124
184
@pytest .mark .parametrize (
125
185
"name, weight" ,
126
186
[
0 commit comments