@@ -631,3 +631,35 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
 
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', ["regnetx_002"])
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_torchscript_with_features_fx(model_name, batch_size):
+    """Create a model with feature extraction based on fx, script it, and run
+    a single forward pass"""
+    if not has_fx_feature_extraction:
+        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+
+    allowed_models = list_models(
+        exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS,
+        name_matches_cfg=True
+    )
+    assert model_name in allowed_models, f"{model_name=} not supported for this test"
+
+    input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
+    assert max(input_size) <= MAX_JIT_SIZE, "Fixed input size model > limit. Pick a different model to run this test"
+
+    with set_scriptable(True):
+        model = create_model(model_name, pretrained=False, features_only=True, feature_cfg={"feature_cls": "fx"})
+    model.eval()
+
+    model = torch.jit.script(model)
+    with torch.no_grad():
+        outputs = model(torch.randn((batch_size, *input_size)))
+
+    assert isinstance(outputs, list)
+
+    for tensor in outputs:
+        assert tensor.shape[0] == batch_size
+        assert not torch.isnan(tensor).any(), 'Output included NaNs'
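For reference, the added test boils down to the standalone snippet below. This is a minimal sketch, not part of the commit: it assumes `timm` exports `create_model` and `set_scriptable` at the top level (as the test module imports them) and that `regnetx_002` accepts the default 3x224x224 input that `_get_input_size` would resolve for it.

```python
# Minimal sketch of what the new test exercises, run outside of pytest.
# Assumes torch >= 1.10 and torchvision >= 0.11 (the FX feature-extraction requirement above).
import torch
from timm import create_model, set_scriptable

with set_scriptable(True):
    # features_only=True with feature_cls="fx" builds the FX-based feature extractor
    model = create_model(
        "regnetx_002", pretrained=False, features_only=True, feature_cfg={"feature_cls": "fx"}
    )
model.eval()

scripted = torch.jit.script(model)
with torch.no_grad():
    # 224x224 is assumed here; the test derives the size via _get_input_size
    features = scripted(torch.randn(1, 3, 224, 224))

# The scripted FX feature extractor returns a list of feature maps, one per output stage
for f in features:
    print(f.shape)
```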