Skip to content

Commit 20fe56b

Browse files
authored
Merge pull request #2217 from dsuess/2216_fix_script_on_features_fx
Fix jit.script breaking with features_fx
2 parents d4ef0b4 + 197c104 commit 20fe56b

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

tests/test_models.py

+32
Original file line numberDiff line numberDiff line change
@@ -631,3 +631,35 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
631631

632632
assert outputs.shape[0] == batch_size
633633
assert not torch.isnan(outputs).any(), 'Output included NaNs'
634+
635+
@pytest.mark.timeout(120)
636+
@pytest.mark.parametrize('model_name', ["regnetx_002"])
637+
@pytest.mark.parametrize('batch_size', [1])
638+
def test_model_forward_torchscript_with_features_fx(model_name, batch_size):
639+
"""Create a model with feature extraction based on fx, script it, and run
640+
a single forward pass"""
641+
if not has_fx_feature_extraction:
642+
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
643+
644+
allowed_models = list_models(
645+
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS,
646+
name_matches_cfg=True
647+
)
648+
assert model_name in allowed_models, f"{model_name=} not supported for this test"
649+
650+
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
651+
assert max(input_size) <= MAX_JIT_SIZE, "Fixed input size model > limit. Pick a different model to run this test"
652+
653+
with set_scriptable(True):
654+
model = create_model(model_name, pretrained=False, features_only=True, feature_cfg={"feature_cls": "fx"})
655+
model.eval()
656+
657+
model = torch.jit.script(model)
658+
with torch.no_grad():
659+
outputs = model(torch.randn((batch_size, *input_size)))
660+
661+
assert isinstance(outputs, list)
662+
663+
for tensor in outputs:
664+
assert tensor.shape[0] == batch_size
665+
assert not torch.isnan(tensor).any(), 'Output included NaNs'

timm/models/_features_fx.py

+4
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str
116116
class FeatureGraphNet(nn.Module):
117117
""" A FX Graph based feature extractor that works with the model feature_info metadata
118118
"""
119+
return_dict: torch.jit.Final[bool]
120+
119121
def __init__(
120122
self,
121123
model: nn.Module,
@@ -155,6 +157,8 @@ class GraphExtractNet(nn.Module):
155157
squeeze_out: if only one output, and output in list format, flatten to single tensor
156158
return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
157159
"""
160+
return_dict: torch.jit.Final[bool]
161+
158162
def __init__(
159163
self,
160164
model: nn.Module,

0 commit comments

Comments
 (0)