Skip to content

Commit 20d6ed1

Browse files
NicolasHugMateuszGuzekpmeier
authored andcommitted
[fbsync] Add filter parameters to list_models() (#7718)
Reviewed By: matteobettini Differential Revision: D48642263 fbshipit-source-id: 7dd986c91115b47383dfa69af070626a85b8bf07 Co-authored-by: Mateusz Guzek <[email protected]> Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Philip Meier <[email protected]>
1 parent de45a8b commit 20d6ed1

File tree

2 files changed

+96
-11
lines changed

2 files changed

+96
-11
lines changed

test/test_extended_models.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,24 +103,84 @@ def test_weights_deserializable(name):
103103
assert pickle.loads(pickle.dumps(weights)) is weights
104104

105105

106+
def get_models_from_module(module):
107+
return [
108+
v.__name__
109+
for k, v in module.__dict__.items()
110+
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
111+
]
112+
113+
106114
@pytest.mark.parametrize(
107115
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
108116
)
109117
def test_list_models(module):
110-
def get_models_from_module(module):
111-
return [
112-
v.__name__
113-
for k, v in module.__dict__.items()
114-
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
115-
]
116-
117118
a = set(get_models_from_module(module))
118119
b = set(x.replace("quantized_", "") for x in models.list_models(module))
119120

120121
assert len(b) > 0
121122
assert a == b
122123

123124

125+
@pytest.mark.parametrize(
126+
"include_filters",
127+
[
128+
None,
129+
[],
130+
(),
131+
"",
132+
"*resnet*",
133+
["*alexnet*"],
134+
"*not-existing-model-for-test?",
135+
["*resnet*", "*alexnet*"],
136+
["*resnet*", "*alexnet*", "*not-existing-model-for-test?"],
137+
("*resnet*", "*alexnet*"),
138+
set(["*resnet*", "*alexnet*"]),
139+
],
140+
)
141+
@pytest.mark.parametrize(
142+
"exclude_filters",
143+
[
144+
None,
145+
[],
146+
(),
147+
"",
148+
"*resnet*",
149+
["*alexnet*"],
150+
["*not-existing-model-for-test?"],
151+
["resnet34", "*not-existing-model-for-test?"],
152+
["resnet34", "*resnet1*"],
153+
("resnet34", "*resnet1*"),
154+
set(["resnet34", "*resnet1*"]),
155+
],
156+
)
157+
def test_list_models_filters(include_filters, exclude_filters):
158+
actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters))
159+
classification_models = set(get_models_from_module(models))
160+
161+
if isinstance(include_filters, str):
162+
include_filters = [include_filters]
163+
if isinstance(exclude_filters, str):
164+
exclude_filters = [exclude_filters]
165+
166+
if include_filters:
167+
expected = set()
168+
for include_f in include_filters:
169+
include_f = include_f.strip("*?")
170+
expected = expected | set(x for x in classification_models if include_f in x)
171+
else:
172+
expected = classification_models
173+
174+
if exclude_filters:
175+
for exclude_f in exclude_filters:
176+
exclude_f = exclude_f.strip("*?")
177+
if exclude_f != "":
178+
a_exclude = set(x for x in classification_models if exclude_f in x)
179+
expected = expected - a_exclude
180+
181+
assert expected == actual
182+
183+
124184
@pytest.mark.parametrize(
125185
"name, weight",
126186
[

torchvision/models/_api.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import fnmatch
12
import importlib
23
import inspect
34
import sys
@@ -6,7 +7,7 @@
67
from functools import partial
78
from inspect import signature
89
from types import ModuleType
9-
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
10+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
1011

1112
from torch import nn
1213

@@ -203,19 +204,43 @@ def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
203204
return wrapper
204205

205206

206-
def list_models(module: Optional[ModuleType] = None) -> List[str]:
207+
def list_models(
208+
module: Optional[ModuleType] = None,
209+
include: Union[Iterable[str], str, None] = None,
210+
exclude: Union[Iterable[str], str, None] = None,
211+
) -> List[str]:
207212
"""
208213
Returns a list with the names of registered models.
209214
210215
Args:
211216
module (ModuleType, optional): The module from which we want to extract the available models.
217+
include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
218+
Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
219+
wildcards. In case of many filters, the results is the union of individual filters.
220+
exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models.
221+
Filter are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
222+
wildcards. In case of many filters, the results is removal of all the models that match any individual filter.
212223
213224
Returns:
214225
models (list): A list with the names of available models.
215226
"""
216-
models = [
227+
all_models = {
217228
k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
218-
]
229+
}
230+
if include:
231+
models: Set[str] = set()
232+
if isinstance(include, str):
233+
include = [include]
234+
for include_filter in include:
235+
models = models | set(fnmatch.filter(all_models, include_filter))
236+
else:
237+
models = all_models
238+
239+
if exclude:
240+
if isinstance(exclude, str):
241+
exclude = [exclude]
242+
for exclude_filter in exclude:
243+
models = models - set(fnmatch.filter(all_models, exclude_filter))
219244
return sorted(models)
220245

221246

0 commit comments

Comments
 (0)