dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml (30 changes: 25 additions & 5 deletions)
@@ -76,12 +76,31 @@
   custom_code_at_the_beginning: |
     return dipu_add__tensor(self, other, -alpha);
 
+- schema: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+  custom_code_at_the_beginning: |
+    at::Tensor out = UnaryOpInferrer().infer_out(self);
+  interface: diopiSubScalar(ctx, out, self, other, alpha)
+
 - schema: "sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
+  dummy_call_diopi: True
+  ins: [selfTmp]
   custom_code_at_the_beginning: |
-    at::native::sub_check(self, other);
-    auto out = BinaryOpInferrer().infer_out(self, other);
-    return dipu_add_out(self, other, -alpha, out);
 
+    if (is_scalar_on_cpu(other)) {
+      at::native::alpha_check(self.scalar_type(), alpha);
+      return dipu_sub_scalar(self, other.item(), alpha);
+    }
+
+    at::Tensor selfTmp = self;
+    if (is_scalar_on_cpu(selfTmp)) {
+      selfTmp = nodispatch::empty({1}, self.options().device(other.device()));
+      dipu_fill__scalar(selfTmp, self.item());
+    }
+
+    at::native::alpha_check(selfTmp.scalar_type(), alpha);
+    at::Tensor out = BinaryOpInferrer().infer_out(selfTmp, other);
+
+  interface: diopiSub(ctx, out, selfTmp, other, alpha)
 
 - schema: "div.Scalar(Tensor self, Scalar other) -> Tensor"
   custom_code_at_the_beginning: |
@@ -1341,9 +1360,10 @@
 
 - schema: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
   custom_code_at_the_beginning: |
+    if (is_scalar_on_cpu(self)) {
+      return dipu_sub_scalar(other, self.item(), alpha);
+    }
     auto out = nodispatch::empty_like(self);
-    // NOLINTNEXTLINE(readability-suspicious-call-argument)
-    return dipu_sub_out(other, self, alpha, out);
   interface: diopiSub(ctx, out, other, self, alpha)
 
 - schema: "unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor out, Tensor indices, Tensor counts)"
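For reference, the scenarios targeted by the new custom code above can be sketched in a few lines of Python. This is a minimal illustration only: "cuda" stands in for whatever device string the DIPU build registers, and the comments refer to the wrapper helpers shown in the diff (is_scalar_on_cpu, dipu_sub_scalar, diopiSub), not to public PyTorch APIs.

import torch

dev = "cuda"  # assumption: stand-in for the actual DIPU device string

x = torch.ones(4, 5, device=dev)

# `other` is a 0-dim CPU tensor (a wrapped scalar). The wrapper now detects this
# via is_scalar_on_cpu(other) and takes the scalar path (dipu_sub_scalar) instead
# of handing a CPU tensor to diopiSub.
y1 = x - torch.tensor(3.0)

# `self` is the 0-dim CPU tensor. The wrapper copies it into a one-element device
# tensor (selfTmp) with empty + fill_ before calling diopiSub.
y2 = torch.tensor(3.0) - x

# rsub computes other - alpha * self, which is why the rsub.Tensor entry invokes
# diopiSub with swapped operands: diopiSub(ctx, out, other, self, alpha).
y3 = torch.rsub(x, torch.full((4, 5), 10.0, device=dev), alpha=2)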
dipu/tests/python/unittests/test_rsub.py (11 changes: 9 additions & 2 deletions)
@@ -23,8 +23,15 @@ def test_rsub(self):
         self._test_rsub(torch.ones(4, 5) * 1.1, torch.ones(4, 5) * 5, alpha=4)
 
     def test_rsub_scalar(self):
-        self._test_rsub_scalar(torch.ones(4, 5), 10)
-        self._test_rsub_scalar(torch.ones(4, 5), 10, 2.5)
+        # from torch:
+        # For integral input tensors, argument alpha must not be a floating point number
+        # Boolean alpha only supported for Boolean results
+        self._test_rsub_scalar(torch.ones(4, 5), 10, alpha=1)
+        self.assertRaisesRegex(
+            RuntimeError,
+            r"For integral input tensors, argument alpha must not be a floating point number\.",
+            lambda: self._test_rsub_scalar(torch.ones(4, 5), 10, 2.5),
+        )
 
 
 if __name__ == "__main__":
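The error message asserted in the new test comes from PyTorch's alpha_check, the same check the updated wrapper code calls before dispatching. As a stand-alone illustration of that rule on stock PyTorch (independent of the DIPU test helper _test_rsub_scalar, whose signature is not shown in this diff):

import torch

a = torch.ones(4, 5, dtype=torch.int32)

# Integral result dtype with an integral alpha is fine.
torch.sub(a, 10, alpha=1)

# Integral result dtype with a floating-point alpha trips alpha_check.
try:
    torch.sub(a, 10, alpha=2.5)
except RuntimeError as e:
    # "For integral input tensors, argument alpha must not be a floating point number."
    print(e)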