@@ -631,3 +631,35 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
631631
632632 assert outputs .shape [0 ] == batch_size
633633 assert not torch .isnan (outputs ).any (), 'Output included NaNs'
634+
635+ @pytest .mark .timeout (120 )
636+ @pytest .mark .parametrize ('model_name' , ["regnetx_002" ])
637+ @pytest .mark .parametrize ('batch_size' , [1 ])
638+ def test_model_forward_torchscript_with_features_fx (model_name , batch_size ):
639+ """Create a model with feature extraction based on fx, script it, and run
640+ a single forward pass"""
641+ if not has_fx_feature_extraction :
642+ pytest .skip ("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required." )
643+
644+ allowed_models = list_models (
645+ exclude_filters = EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS ,
646+ name_matches_cfg = True
647+ )
648+ assert model_name in allowed_models , f"{ model_name = } not supported for this test"
649+
650+ input_size = _get_input_size (model_name = model_name , target = TARGET_JIT_SIZE )
651+ assert max (input_size ) <= MAX_JIT_SIZE , "Fixed input size model > limit. Pick a different model to run this test"
652+
653+ with set_scriptable (True ):
654+ model = create_model (model_name , pretrained = False , features_only = True , feature_cfg = {"feature_cls" : "fx" })
655+ model .eval ()
656+
657+ model = torch .jit .script (model )
658+ with torch .no_grad ():
659+ outputs = model (torch .randn ((batch_size , * input_size )))
660+
661+ assert isinstance (outputs , list )
662+
663+ for tensor in outputs :
664+ assert tensor .shape [0 ] == batch_size
665+ assert not torch .isnan (tensor ).any (), 'Output included NaNs'
0 commit comments