@@ -658,6 +658,36 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
     else:
         return (input, )
 
+def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((S, M), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((M,), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
+def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((M, S, M), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((M, M, S), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
+def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((S,), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((S,), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
 def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs):
     test_cases = (((S,), (S, M), (M,), 1, 1, False),
                   ((S,), (S, M), (M,), 0.2, 0.6, False),
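The three helpers added above only fix the shape contracts that the new OpInfo entries will exercise. As a rough illustration (a sketch, not part of the diff; S and M stand for the small size constants this file defines, assumed here to be 5 and 10), the sample shapes line up with the underlying calls like so:

import torch

S, M = 5, 10                                         # assumed stand-ins for the file's size constants
mat, vec = torch.randn(S, M), torch.randn(M)
print(torch.mv(mat, vec).shape)                      # torch.Size([5])        -- shapes from sample_inputs_mv
b1, b2 = torch.randn(M, S, M), torch.randn(M, M, S)
print(torch.bmm(b1, b2).shape)                       # torch.Size([10, 5, 5]) -- shapes from sample_inputs_bmm
u, v = torch.randn(S), torch.randn(S)
print(torch.dot(u, v), torch.vdot(u, v))             # 0-dim results          -- shapes from sample_inputs_dot_vdot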
@@ -3047,8 +3077,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
     OpInfo('addbmm',
            dtypes=floating_types(),
            dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
-                                           *[torch.bfloat16] if CUDA11OrLater else []),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            dtypesIfROCM=floating_types_and(torch.half),
            skips=(
                # addbmm does not correctly warn when resizing out= inputs
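The addbmm change in this hunk is intended as a pure simplification: floating_and_complex_types_and(torch.float16, ...) should cover exactly the dtypes the removed two-line floating_types_and(...) call listed by hand. A quick sanity-check sketch (not part of the diff; assumes the dtype helpers and CUDA11OrLater already used by this file are in scope):

extra = [torch.bfloat16] if CUDA11OrLater else []
old_set = set(floating_types_and(torch.float16, torch.complex64, torch.complex128, *extra))
new_set = set(floating_and_complex_types_and(torch.float16, *extra))
assert old_set == new_set  # float16/32/64 + complex64/128 (+ bfloat16 on CUDA 11)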
@@ -3061,6 +3090,50 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.bfloat16,),
                         device_type='cuda', active_if=not SM53OrLater)),
            sample_inputs_func=sample_inputs_addbmm),
+    OpInfo('dot',
+           dtypes=all_types_and_complex_and(torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           skips=(
+               # dot does not correctly handle out= dtypes
+               # https://github.com/pytorch/pytorch/issues/55561
+               SkipInfo('TestCommon', 'test_out'),
+           ),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_dot_vdot),
+    OpInfo('vdot',
+           dtypes=all_types_and_complex_and(torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           skips=(
+               # vdot does not correctly handle out= dtypes
+               # https://github.com/pytorch/pytorch/issues/55561
+               SkipInfo('TestCommon', 'test_out'),
+           ),
+           sample_inputs_func=sample_inputs_dot_vdot),
+    OpInfo('bmm',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           assert_autodiffed=True,
+           skips=(
+               # bmm does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
+               # cuda gradchecks are slow
+               # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+               SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.bfloat16,),
+                        device_type='cuda', active_if=not SM53OrLater)),
+           sample_inputs_func=sample_inputs_bmm),
+    OpInfo('mv',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           skips=(
+               # mv does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
+               SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.float16,)),
+               # mv calls into addmv which doesn't fully support float16
+               # RuntimeError: "addmv_impl_cpu" not implemented for 'Half'
+               SkipInfo('TestOpInfo', 'test_supported_dtypes', dtypes=(torch.float16,)),),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_mv),
     OpInfo('addr',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            # Reference: https://github.com/pytorch/pytorch/issues/50747
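Each new OpInfo above points at one of the sample_inputs_* helpers from the first hunk; the OpInfo-based tests iterate over the SampleInputs those helpers return. A minimal sketch of that flow (not the framework's actual test driver; assumes SampleInput exposes .input and .args as it does in this file, and passes None for the unused self/op_info argument):

samples = sample_inputs_dot_vdot(None, 'cpu', torch.float64, requires_grad=True)
for sample in samples:
    out = torch.dot(sample.input, *sample.args)  # forward through the op under test
    out.backward()                               # backward coverage (dot has assert_autodiffed=True)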
@@ -5270,10 +5343,6 @@ def method_tests():
         ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
         ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), ident,
          {'beta': 0.2, 'alpha': 0.6}),
-        ('dot', (L,), ((L,),), '', (True,)),
-        ('vdot', (L,), ((L,),),),
-        ('bmm', (M, S, M), ((M, M, S),), '', (True,)),
-        ('mv', (S, M), ((M,),), '', (True,)),
         ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),
         ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"),
         ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),
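The four method_tests() entries removed in this last hunk were the legacy tuple-based specs for dot, vdot, bmm, and mv, now superseded by the OpInfo entries added earlier. Each tuple encodes roughly (name, self-tensor shape, argument shapes, variant name, ...); for example, the removed 'mv' line corresponds to a check along these lines (a sketch, not part of the diff; S and M are the file's small size constants):

x = torch.randn(S, M, dtype=torch.double, requires_grad=True)
v = torch.randn(M, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(torch.mv, (x, v))  # the kind of gradient check the old path ran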