@@ -7673,10 +7673,11 @@ class TestVarStd(TorchBaseTest):
76737673 @pytest .mark .parametrize (
76747674 "compute_unit, backend, frontend, torch_op, unbiased" ,
76757675 itertools .product (
7676- compute_units , backends , frontends , [torch . var , torch . std ], [True , False ]
7676+ compute_units , backends , frontends , [" var" , " std" ], [True , False ]
76777677 ),
76787678 )
76797679 def test_var_std_2_inputs (self , compute_unit , backend , frontend , torch_op , unbiased ):
7680+ torch_op = getattr (torch , torch_op )
76807681 model = ModuleWrapper (function = torch_op , kwargs = {"unbiased" : unbiased })
76817682 x = torch .randn (1 , 5 , 10 ) * 3
76827683 out = torch_op (x , unbiased = unbiased ).unsqueeze (0 )
@@ -7696,7 +7697,7 @@ def test_var_std_2_inputs(self, compute_unit, backend, frontend, torch_op, unbia
76967697 compute_units ,
76977698 backends ,
76987699 frontends ,
7699- [torch . var , torch . std ],
7700+ [" var" , " std" ],
77007701 [True , False ],
77017702 [[0 , 2 ], [1 ], [2 ]],
77027703 [True , False ],
@@ -7705,6 +7706,7 @@ def test_var_std_2_inputs(self, compute_unit, backend, frontend, torch_op, unbia
77057706 def test_var_std_4_inputs (
77067707 self , compute_unit , backend , frontend , torch_op , unbiased , dim , keepdim
77077708 ):
7709+ torch_op = getattr (torch , torch_op )
77087710 model = ModuleWrapper (
77097711 function = torch_op ,
77107712 kwargs = {"unbiased" : unbiased , "dim" : dim , "keepdim" : keepdim },
@@ -7720,7 +7722,7 @@ def test_var_std_4_inputs(
77207722 compute_units ,
77217723 backends ,
77227724 frontends ,
7723- [torch . var , torch . std ],
7725+ [" var" , " std" ],
77247726 [0 , 1 ],
77257727 [[0 , 2 ], [1 ], [2 ]],
77267728 [True , False ],
@@ -7729,6 +7731,7 @@ def test_var_std_4_inputs(
77297731 def test_var_std_with_correction (
77307732 self , compute_unit , backend , frontend , torch_op , correction , dim , keepdim
77317733 ):
7734+ torch_op = getattr (torch , torch_op )
77327735 model = ModuleWrapper (
77337736 function = torch_op ,
77347737 kwargs = {"correction" : correction , "dim" : dim , "keepdim" : keepdim },
@@ -9103,34 +9106,35 @@ def generate_tensor_rank_5(self, x):
91039106 backends ,
91049107 frontends ,
91059108 [
9106- torch . abs ,
9107- torch . acos ,
9108- torch . asin ,
9109- torch . atan ,
9110- torch . atanh ,
9111- torch . ceil ,
9112- torch . cos ,
9113- torch . cosh ,
9114- torch . exp ,
9115- torch . exp2 ,
9116- torch . floor ,
9117- torch . log ,
9118- torch . log2 ,
9119- torch . round ,
9120- torch . rsqrt ,
9121- torch . sign ,
9122- torch . sin ,
9123- torch . sinh ,
9124- torch . sqrt ,
9125- torch . square ,
9126- torch . tan ,
9127- torch . tanh ,
9109+ " abs" ,
9110+ " acos" ,
9111+ " asin" ,
9112+ " atan" ,
9113+ " atanh" ,
9114+ " ceil" ,
9115+ " cos" ,
9116+ " cosh" ,
9117+ " exp" ,
9118+ " exp2" ,
9119+ " floor" ,
9120+ " log" ,
9121+ " log2" ,
9122+ " round" ,
9123+ " rsqrt" ,
9124+ " sign" ,
9125+ " sin" ,
9126+ " sinh" ,
9127+ " sqrt" ,
9128+ " square" ,
9129+ " tan" ,
9130+ " tanh" ,
91289131 ],
91299132 ),
91309133 )
91319134 def test_torch_rank0_tensor (self , compute_unit , backend , frontend , torch_op ):
9132- if frontend == TorchFrontend .EXECUTORCH and torch_op == torch . exp2 :
9135+ if frontend == TorchFrontend .EXECUTORCH and torch_op == " exp2" :
91339136 pytest .skip ("torch._ops.aten.exp2.default is not Aten Canonical" )
9137+ torch_op = getattr (torch , torch_op )
91349138
91359139 class Model (nn .Module ):
91369140 def forward (self , x : torch .Tensor ) -> torch .Tensor :
@@ -12658,7 +12662,7 @@ class TestSTFT(TorchBaseTest):
1265812662 [16 ], # n_fft
1265912663 [None , 4 , 5 ], # hop_length
1266012664 [None , 16 , 9 ], # win_length
12661- [None , torch . hann_window ], # window
12665+ [None , " hann_window" ], # window
1266212666 [None , False , True ], # center
1266312667 ["constant" , "reflect" , "replicate" ], # pad mode
1266412668 [False , True ], # normalized
@@ -12668,6 +12672,8 @@ class TestSTFT(TorchBaseTest):
1266812672 def test_stft (self , compute_unit , backend , input_shape , complex , n_fft , hop_length , win_length , window , center , pad_mode , normalized , onesided ):
1266912673 if complex and onesided :
1267012674 pytest .skip ("Onesided stft not possible for complex inputs" )
12675+ if window is not None :
12676+ window = getattr (torch , window )
1267112677
1267212678 class STFTModel (torch .nn .Module ):
1267312679 def forward (self , x ):
0 commit comments