@@ -658,6 +658,36 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
     else:
         return (input, )

+def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((S, M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
+def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((M, S, M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((M, M, S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
+def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
 def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs):
     test_cases = (((S,), (S, M), (M,), 1, 1, False),
                   ((S,), (S, M), (M,), 0.2, 0.6, False),
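For reference, a minimal standalone sketch (not part of the patch) of the call patterns these new sample-input functions exercise; S and M are the module's small shape constants (historically S = 5 and M = 10, assumed below):

```python
import torch

S, M = 5, 10  # assumed values of the test-size constants referenced above

# sample_inputs_mv: (S, M) matrix times (M,) vector -> (S,) vector
mat = torch.randn(S, M)
vec = torch.randn(M)
print(torch.mv(mat, vec).shape)    # torch.Size([5])

# sample_inputs_bmm: batch of M matrix products, (M, S, M) @ (M, M, S) -> (M, S, S)
b1 = torch.randn(M, S, M)
b2 = torch.randn(M, M, S)
print(torch.bmm(b1, b2).shape)     # torch.Size([10, 5, 5])

# sample_inputs_dot_vdot: two length-S vectors -> 0-dim result
x = torch.randn(S)
y = torch.randn(S)
print(torch.dot(x, y).shape)       # torch.Size([])
print(torch.vdot(x, y))            # equals dot for real inputs; conjugates the first arg for complex
```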
@@ -3047,8 +3077,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
     OpInfo('addbmm',
            dtypes=floating_types(),
            dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
-                                           *[torch.bfloat16] if CUDA11OrLater else []),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            dtypesIfROCM=floating_types_and(torch.half),
            skips=(
                # addbmm does not correctly warn when resizing out= inputs
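# addbmm">
The dtypesIfCUDA change above only shortens the spelling; both forms name the same dtype set. A rough equivalence check, assuming the usual meanings of the dtype helpers imported at the top of this file:

```python
import torch

# What the two spellings expand to, roughly (ignoring the conditional bfloat16 entry,
# which both forms add only when CUDA11OrLater is true).
old_spelling = (torch.float32, torch.float64,            # floating_types_and(...)
                torch.float16, torch.complex64, torch.complex128)
new_spelling = (torch.float32, torch.float64,            # floating_and_complex_types_and(...)
                torch.complex64, torch.complex128, torch.float16)
assert set(old_spelling) == set(new_spelling)
```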
@@ -3061,6 +3090,50 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.bfloat16, ),
                         device_type='cuda', active_if=not SM53OrLater)),
            sample_inputs_func=sample_inputs_addbmm),
+    OpInfo('dot',
+           dtypes=all_types_and_complex_and(torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           skips=(
+               # dot does not correctly handle out= dtypes
+               # https://github.com/pytorch/pytorch/issues/55561
+               SkipInfo('TestCommon', 'test_out'),
+           ),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_dot_vdot),
+    OpInfo('vdot',
+           dtypes=all_types_and_complex_and(torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           skips=(
+               # vdot does not correctly handle out= dtypes
+               # https://github.com/pytorch/pytorch/issues/55561
+               SkipInfo('TestCommon', 'test_out'),
+           ),
+           sample_inputs_func=sample_inputs_dot_vdot),
+    OpInfo('bmm',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           assert_autodiffed=True,
+           skips=(
+               # bmm does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
+               # cuda gradchecks are slow
+               # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+               SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.bfloat16,),
+                        device_type='cuda', active_if=not SM53OrLater)),
+           sample_inputs_func=sample_inputs_bmm),
+    OpInfo('mv',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           skips=(
+               # mv does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
+               SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.float16,)),
+               # mv calls into addmv which doesn't fully support float16
+               # RuntimeError: "addmv_impl_cpu" not implemented for 'Half'
+               SkipInfo('TestOpInfo', 'test_supported_dtypes', dtypes=(torch.float16,)),),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_mv),
     OpInfo('addr',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            # Reference: https://github.com/pytorch/pytorch/issues/50747
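# Reference">
The SkipInfo entries above reference the generic suites (TestCommon, TestGradients, TestOpInfo) that consume these OpInfos. As a rough standalone approximation of what the gradient checks verify for the newly covered ops, assuming the same S/M shapes as the sample-input functions:

```python
import torch
from torch.autograd import gradcheck

S, M = 5, 10  # assumed test sizes, matching sample_inputs_mv / sample_inputs_dot_vdot

# gradcheck expects double-precision leaf tensors with requires_grad=True.
mat = torch.randn(S, M, dtype=torch.double, requires_grad=True)
vec = torch.randn(M, dtype=torch.double, requires_grad=True)
assert gradcheck(torch.mv, (mat, vec))

x = torch.randn(S, dtype=torch.double, requires_grad=True)
y = torch.randn(S, dtype=torch.double, requires_grad=True)
assert gradcheck(torch.dot, (x, y))
```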
@@ -5270,10 +5343,6 @@ def method_tests():
         ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
         ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), ident,
             {'beta': 0.2, 'alpha': 0.6}),
-        ('dot', (L,), ((L,),), '', (True,)),
-        ('vdot', (L,), ((L,),),),
-        ('bmm', (M, S, M), ((M, M, S),), '', (True,)),
-        ('mv', (S, M), ((M,),), '', (True,)),
         ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),
         ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"),
         ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),