Skip to content

Commit f32a4e6

Browse files
author
Hariprasad Ravishankar
committed
[TOSA] Notify failure during torch-to-tosa when AtenEmptyMemoryFormat receives a tensor with a dimension of size zero
1 parent 7e1cd8b commit f32a4e6

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6435,6 +6435,11 @@ class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
64356435
for (auto s : shape)
64366436
size *= s;
64376437

6438+
if (size == 0) {
6439+
return rewriter.notifyMatchFailure(
6440+
op, "Shape must not have a dimension of size zero");
6441+
}
6442+
64386443
SmallVector<int32_t> values(size, fillVal);
64396444
auto constOp =
64406445
tosa::getConstTensor<int32_t>(rewriter, op, values, shape).value();

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4370,3 +4370,17 @@ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch
43704370
%0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
43714371
return %0 : !torch.vtensor<[2,3],f16>
43724372
}
4373+
4374+
// -----
4375+
func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
4376+
%c1 = torch.constant.int 1
4377+
%c0 = torch.constant.int 0
4378+
%c256 = torch.constant.int 256
4379+
%2452 = torch.prim.ListConstruct %c1, %c0, %c256 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
4380+
%none = torch.constant.none
4381+
%cpu = torch.constant.device "cpu"
4382+
%false = torch.constant.bool false
4383+
// expected-error @below {{failed to legalize operation 'torch.aten.empty.memory_format' that was explicitly marked illegal}}
4384+
%out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32>
4385+
return %out : !torch.vtensor<[1,0,256],f32>
4386+
}

0 commit comments

Comments
 (0)