Skip to content

Commit 0b713d2

Browse files
authored
Fix torch.addmm and add unit tests (#2122)
1 parent c3f1445 commit 0b713d2

File tree

2 files changed

+42
-16
lines changed

2 files changed

+42
-16
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -921,34 +921,29 @@ def cumsum(context, node):
921921

922922
@register_torch_op
def addmm(context, node):
    # addmm(Tensor x, Tensor mat1, Tensor mat2, Scalar beta=1, Scalar alpha=1)
    # output = beta * x + alpha * (mat1 @ mat2)
    assert len(node.outputs) == 1
    inputs = _get_inputs(context, node, expected=[3, 4, 5])
    x = inputs[0]
    mat1 = inputs[1]
    mat2 = inputs[2]
    # beta and alpha are optional in the torch op; both default to 1.0.
    beta = inputs[3] if len(inputs) > 3 else mb.const(val=1.0)
    alpha = inputs[4] if len(inputs) > 4 else mb.const(val=1.0)

    if beta.val != 1.0:
        # Apply beta scaling factor to the input.
        x = mb.mul(x=x, y=beta)

    matmul = mb.matmul(x=mat1, y=mat2)

    if alpha.val != 1.0:
        # Apply alpha scaling factor to the matrix multiplication.
        matmul = mb.mul(x=alpha, y=matmul)

    result = mb.add(x=x, y=matmul, name=node.name)
    context.add(result)
@register_torch_op

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8505,6 +8505,37 @@ def forward(self, x, y, z):
85058505
)
85068506

85078507

class TestAddmm(TorchBaseTest):
    @pytest.mark.parametrize(
        "compute_unit, backend, shapes, beta, alpha",
        itertools.product(
            compute_units,
            backends,
            ((2, 2, 2), (4, 5, 9)),
            (1., 2.),
            (1., 3.),
        )
    )
    def test_addmm(self, compute_unit, backend, shapes, beta, alpha):
        m, n, p = shapes

        # The product m1 @ m2 must be well-defined, so the inner
        # dimensions of the two matrices agree: (m, n) @ (n, p).
        m1 = torch.randn(m, n)
        m2 = torch.randn(n, p)
        # x is added to m1 @ m2, so it takes that product's shape (m, p).
        x_shape = (m, p)

        class TestModel(nn.Module):
            def forward(self, x):
                return torch.addmm(x, m1, m2, beta=beta, alpha=alpha)

        self.run_compare_torch(
            x_shape, TestModel(), backend=backend, compute_unit=compute_unit,
        )
85088539
class TestScatter(TorchBaseTest):
85098540
@pytest.mark.parametrize(
85108541
"compute_unit, backend, shapes_dims, minimum_deployment_target",

0 commit comments

Comments
 (0)