
Commit 062e705

anjali411 authored and facebook-github-bot committed
Add OpInfo tests for torch.{dot, vdot, bmm, mv} (pytorch#56409)
Summary: Pull Request resolved: pytorch#56409

Reviewed By: nikithamalgifb

Differential Revision: D27870769

Pulled By: anjali411

fbshipit-source-id: a1a0e89856529a4739c7612c5b1e3c5ed2569126
1 parent e4faebc commit 062e705

File tree: 2 files changed (+77, -10 lines)

test/test_autograd.py (+2, -4)
@@ -5385,10 +5385,8 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
                      'expand', 'rot90', 'transpose',
                      'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
                      'chunk', 'split', 'split_with_sizes', 'zero_',
-                     '__radd__', 'sum', 'mul',
-                     '__rmul__', 'dot', 'vdot', 'matmul',
-                     'bmm', 'mv', 'ger', 'diagonal', 'fill_', 'sub',
-                     'mean', 'inverse', 'linalg.tensorinv', 'matrix_exp',
+                     '__radd__', 'mul', '__rmul__', 'matmul',
+                     'diagonal', 'fill_', 'sub',
                      'narrow', 'swapaxes', 'swapdims', 'tensor_split',
                      'baddbmm'] + complex_list_filter + separate_complex_tests
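
These ops drop out of test_autograd.py's complex-ops list because their complex autograd coverage now comes from the new OpInfo entries below. As a standalone illustration of what that coverage checks, a minimal sketch using only public APIs (shapes chosen arbitrarily):

import torch

# gradcheck compares analytical gradients against numerical ones; for
# complex inputs it also exercises the conjugation behavior in which
# dot and vdot differ.
x = torch.randn(5, dtype=torch.complex128, requires_grad=True)
y = torch.randn(5, dtype=torch.complex128, requires_grad=True)
assert torch.autograd.gradcheck(torch.dot, (x, y))
assert torch.autograd.gradcheck(torch.vdot, (x, y))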

torch/testing/_internal/common_methods_invocations.py (+75, -6)
@@ -658,6 +658,36 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
     else:
         return (input, )

+def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((S, M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
+def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((M, S, M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((M, M, S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
+def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
 def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs):
     test_cases = (((S,), (S, M), (M,), 1, 1, False),
                   ((S,), (S, M), (M,), 0.2, 0.6, False),
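
Each sample_inputs_* function returns a tuple of SampleInput objects bundling an input tensor with the positional args to pass to the op. A minimal sketch of how such a sampler is consumed (an assumed harness, not the real consumer in test/test_ops.py; the imports are from an internal module whose contents shift between releases):

import torch
from torch.testing._internal.common_methods_invocations import (
    S, M, sample_inputs_mv)

# The first parameter (the OpInfo, named `self` above) is unused by this
# sampler, so None suffices here.
for sample in sample_inputs_mv(None, 'cpu', torch.float64, requires_grad=True):
    out = torch.mv(sample.input, *sample.args)
    assert out.shape == (S,)  # (S, M) @ (M,) -> (S,)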
@@ -3047,8 +3077,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
     OpInfo('addbmm',
            dtypes=floating_types(),
            dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
-                                           *[torch.bfloat16] if CUDA11OrLater else []),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            dtypesIfROCM=floating_types_and(torch.half),
            skips=(
                # addbmm does not correctly warn when resizing out= inputs
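
The addbmm change is a pure simplification: floating_and_complex_types_and(xs...) is by construction the same set as floating_types_and(xs..., torch.complex64, torch.complex128), so CUDA dtype coverage is unchanged. A quick self-check (helper import path assumed; these helpers have moved between internal modules across releases):

import torch
from torch.testing._internal.common_dtype import (
    floating_types_and, floating_and_complex_types_and)

old = set(floating_types_and(torch.float16, torch.complex64, torch.complex128))
new = set(floating_and_complex_types_and(torch.float16))
assert old == new  # the refactor changes no dtype coverage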
@@ -3061,6 +3090,50 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
                SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.bfloat16, ),
                         device_type='cuda', active_if=not SM53OrLater)),
            sample_inputs_func=sample_inputs_addbmm),
+    OpInfo('dot',
+           dtypes=all_types_and_complex_and(torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           skips=(
+               # dot does not correctly handle out= dtypes
+               # https://github.com/pytorch/pytorch/issues/55561
+               SkipInfo('TestCommon', 'test_out'),
+           ),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_dot_vdot),
+    OpInfo('vdot',
+           dtypes=all_types_and_complex_and(torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           skips=(
+               # vdot does not correctly handle out= dtypes
+               # https://github.com/pytorch/pytorch/issues/55561
+               SkipInfo('TestCommon', 'test_out'),
+           ),
+           sample_inputs_func=sample_inputs_dot_vdot),
+    OpInfo('bmm',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           assert_autodiffed=True,
+           skips=(
+               # bmm does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
+               # cuda gradchecks are slow
+               # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+               SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.bfloat16, ),
+                        device_type='cuda', active_if=not SM53OrLater)),
+           sample_inputs_func=sample_inputs_bmm),
+    OpInfo('mv',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           skips=(
+               # mv does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
+               SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.float16,)),
+               # mv calls into addmv which doesn't fully support float16
+               # RuntimeError: "addmv_impl_cpu" not implemented for 'Half'
+               SkipInfo('TestOpInfo', 'test_supported_dtypes', dtypes=(torch.float16,)),),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_mv),
     OpInfo('addr',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            # Reference: https://github.com/pytorch/pytorch/issues/50747
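
Once registered in op_db, these four entries are picked up automatically by the generic suites in test/test_ops.py (TestCommon, TestGradients, TestOpInfo); each SkipInfo above disables one known-bad test/dtype/device combination rather than the whole op. A rough sketch of the gradient coverage this buys (an assumed driver loop, not the actual test generator):

import torch
from torch.testing._internal.common_methods_invocations import op_db

for name in ('dot', 'vdot', 'bmm', 'mv'):
    info = next(op for op in op_db if op.name == name)
    for sample in info.sample_inputs('cpu', torch.complex128, requires_grad=True):
        # info.op is the underlying callable, e.g. torch.bmm
        assert torch.autograd.gradcheck(info.op, (sample.input, *sample.args))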
@@ -5270,10 +5343,6 @@ def method_tests():
     ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
     ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), ident,
      {'beta': 0.2, 'alpha': 0.6}),
-    ('dot', (L,), ((L,),), '', (True,)),
-    ('vdot', (L,), ((L,),),),
-    ('bmm', (M, S, M), ((M, M, S),), '', (True,)),
-    ('mv', (S, M), ((M,),), '', (True,)),
     ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),
     ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"),
     ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),
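
The removed method_tests() tuples follow the legacy schema (op_name, input_shape, arg_shapes, variant_name, autodiff_flags); the OpInfo entries above subsume them. Decoded, the removed mv entry is roughly this standalone check (S and M are the suite's size constants; the concrete values here are assumed for illustration):

import torch

S, M = 5, 10  # assumed values of the test-size constants
mat = torch.randn(S, M, dtype=torch.float64, requires_grad=True)
vec = torch.randn(M, dtype=torch.float64, requires_grad=True)
# equivalent of the removed ('mv', (S, M), ((M,),), '', (True,)) entry
assert torch.autograd.gradcheck(torch.mv, (mat, vec))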
